mirror of
https://github.com/langgenius/dify.git
synced 2026-05-27 04:16:16 +08:00
Compare commits
1 Commits
build/cli
...
codex/draw
| Author | SHA1 | Date | |
|---|---|---|---|
| ce18fc7d6c |
@ -4,7 +4,7 @@ from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
@ -17,17 +17,18 @@ from controllers.openapi._models import (
|
||||
SessionRow,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
Scope,
|
||||
TokenType,
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
)
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -41,18 +42,32 @@ from services.oauth_device_flow import (
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}")
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
account_id_str = str(auth_data.account_id) if auth_data.account_id else None
|
||||
account = AccountService.get_account_by_id(db.session, account_id_str) if account_id_str else None
|
||||
memberships = TenantService.get_account_memberships(db.session, account_id_str) if account_id_str else []
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||
else:
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email,
|
||||
subject_issuer=ctx.subject_issuer,
|
||||
account=None,
|
||||
workspaces=[],
|
||||
default_workspace_id=None,
|
||||
).model_dump(mode="json")
|
||||
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
|
||||
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return AccountResponse(
|
||||
subject_type="account",
|
||||
subject_email=account.email if account else None,
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email or (account.email if account else None),
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
@ -62,17 +77,19 @@ class AccountApi(Resource):
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def delete(self, *, auth_data: AuthData):
|
||||
revoke_oauth_token(db.session, redis_client, str(auth_data.token_id))
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
@ -105,9 +122,10 @@ class AccountSessionsApi(Resource):
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def delete(self, session_id: str, *, auth_data: AuthData):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self, session_id: str):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
|
||||
# 404 (not 403) on cross-subject so the endpoint doesn't leak
|
||||
# token IDs that belong to other subjects.
|
||||
@ -118,6 +136,13 @@ class AccountSessionByIdApi(Resource):
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||
if not ctx.source.startswith("oauth"):
|
||||
raise BadRequest(
|
||||
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
)
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
@ -16,8 +16,7 @@ import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
@ -125,11 +124,8 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model = auth_data.app
|
||||
caller = auth_data.caller
|
||||
caller_kind = auth_data.caller_kind
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
@ -162,11 +158,8 @@ class AppRunApi(Resource):
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@openapi_ns.response(200, "Task stopped")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model = auth_data.app
|
||||
caller = auth_data.caller
|
||||
caller_kind = auth_data.caller_kind
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
return {"result": "success"}
|
||||
|
||||
@ -1,4 +1,9 @@
|
||||
"""GET /openapi/v1/apps and per-app reads."""
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → publishes the auth ContextVar before `require_scope`
|
||||
reads it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -23,17 +28,31 @@ from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
accept_subjects(SubjectType.ACCOUNT),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
@ -47,9 +66,13 @@ _EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for membership/exists checks."""
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> App:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
@ -76,7 +99,8 @@ class AppReadResource(Resource):
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
return app
|
||||
require_workspace_member(ctx, str(app.tenant_id))
|
||||
return app, ctx
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
@ -90,14 +114,13 @@ def parameters_payload(app: App) -> dict:
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, app_id: str, *, auth_data: AuthData):
|
||||
def get(self, app_id: str):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app = self._load(app_id, workspace_id=query.workspace_id)
|
||||
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
@ -145,16 +168,20 @@ class AppDescribeApi(AppReadResource):
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
def get(self):
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
workspace_id = query.workspace_id
|
||||
require_workspace_member(ctx, workspace_id)
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
@ -210,7 +237,7 @@ class AppListApi(Resource):
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
|
||||
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
|
||||
@ -18,27 +18,37 @@ from controllers.openapi._models import (
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, Edition
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
from services.openapi.license_gate import license_required
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
method_decorators = [
|
||||
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
|
||||
license_required,
|
||||
accept_subjects(SubjectType.EXTERNAL_SSO),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
enterprise_only,
|
||||
]
|
||||
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
|
||||
edition=frozenset({Edition.EE}),
|
||||
)
|
||||
def get(self, *, auth_data: AuthData):
|
||||
def get(self):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
|
||||
__all__ = ["auth_router"]
|
||||
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||
|
||||
@ -1,64 +1,46 @@
|
||||
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_CE,
|
||||
EDITION_EE,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.data import Edition
|
||||
from controllers.openapi.auth.flow import When
|
||||
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
||||
from controllers.openapi.auth.prepare import (
|
||||
build_external_identity,
|
||||
load_account,
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
resolve_external_user,
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_access,
|
||||
check_membership,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
account_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
load_account, # unconditional — this IS the account pipeline
|
||||
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
check_scope,
|
||||
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
|
||||
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
external_sso_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
build_external_identity,
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
When(PATH_HAS_APP_ID, then=resolve_external_user),
|
||||
When(PATH_HAS_APP_ID, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
check_scope,
|
||||
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
auth_router = PipelineRouter({
|
||||
TokenType.OAUTH_ACCOUNT: PipelineRoute(account_pipeline),
|
||||
TokenType.OAUTH_EXTERNAL_SSO: PipelineRoute(external_sso_pipeline, required_edition=frozenset({Edition.EE})),
|
||||
})
|
||||
|
||||
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
WorkspaceMembershipCheck(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
)
|
||||
|
||||
@ -1,53 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition
|
||||
from libs.oauth_bearer import TokenType
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
CondFn = Callable[[RequestContext, AuthData | None], bool]
|
||||
|
||||
|
||||
class Cond:
|
||||
def __init__(self, fn: CondFn) -> None:
|
||||
self._fn = fn
|
||||
|
||||
def __call__(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
|
||||
return self._fn(ctx, data)
|
||||
|
||||
def __and__(self, other: Cond) -> Cond:
|
||||
return Cond(lambda ctx, data: self(ctx, data) and other(ctx, data))
|
||||
|
||||
def __or__(self, other: Cond) -> Cond:
|
||||
return Cond(lambda ctx, data: self(ctx, data) or other(ctx, data))
|
||||
|
||||
def __invert__(self) -> Cond:
|
||||
return Cond(lambda ctx, data: not self(ctx, data))
|
||||
|
||||
|
||||
def request_cond(fn: Callable[[RequestContext], bool]) -> Cond:
|
||||
return Cond(lambda ctx, _: fn(ctx))
|
||||
|
||||
|
||||
def data_cond(fn: Callable[[AuthData], bool]) -> Cond:
|
||||
return Cond(lambda _, data: data is not None and fn(data))
|
||||
|
||||
|
||||
def config_cond(fn: Callable[[], bool]) -> Cond:
|
||||
return Cond(lambda _, __: fn())
|
||||
|
||||
|
||||
TOKEN_IS_OAUTH_ACCOUNT = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
|
||||
TOKEN_IS_OAUTH_EXTERNAL_SSO = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_EXTERNAL_SSO)
|
||||
|
||||
PATH_HAS_APP_ID = request_cond(lambda ctx: "app_id" in ctx.path_params)
|
||||
|
||||
EDITION_CE = config_cond(lambda: current_edition() == Edition.CE)
|
||||
EDITION_EE = config_cond(lambda: current_edition() == Edition.EE)
|
||||
EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
|
||||
|
||||
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
|
||||
|
||||
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)
|
||||
68
api/controllers/openapi/auth/context.py
Normal file
68
api/controllers/openapi/auth/context.py
Normal file
@ -0,0 +1,68 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
|
||||
Context is intentionally decoupled from Flask's ``Request``: the pipeline
|
||||
guard extracts whatever transport-level inputs the steps need (bearer
|
||||
token, path params) at the boundary and writes them into Context fields,
|
||||
so steps stay testable without a request object and won't leak coupling
|
||||
to a specific framework.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from contextvars import Token
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
required_scope: Scope
|
||||
bearer_token: str | None = None
|
||||
path_params: Mapping[str, str] = field(default_factory=dict)
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||
token_id: uuid.UUID | None = None
|
||||
token_hash: str | None = None
|
||||
cached_verified_tenants: dict[str, bool] | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
auth_ctx_reset_token: Token[AuthContext] | None = None
|
||||
|
||||
@property
|
||||
def must_tenant(self) -> Tenant:
|
||||
if not self.tenant:
|
||||
raise Unauthorized("tenant is not associated")
|
||||
return self.tenant
|
||||
|
||||
@property
|
||||
def must_subject_type(self) -> SubjectType:
|
||||
if not self.subject_type:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
return self.subject_type
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
@ -1,62 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from configs import dify_config
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Account, Tenant
|
||||
from models.model import App, EndUser
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
class Edition(StrEnum):
|
||||
CE = "ce"
|
||||
EE = "ee"
|
||||
SAAS = "saas"
|
||||
|
||||
|
||||
def current_edition() -> Edition:
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
return Edition.SAAS
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return Edition.EE
|
||||
return Edition.CE
|
||||
|
||||
|
||||
class ExternalIdentity(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
email: str
|
||||
issuer: str | None = None
|
||||
|
||||
|
||||
class RequestContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
token_type: TokenType
|
||||
scope: Scope | None = None
|
||||
path_params: dict[str, str]
|
||||
|
||||
|
||||
class AuthData(BaseModel):
|
||||
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
|
||||
|
||||
required_scope: Scope | None = None
|
||||
token_type: TokenType
|
||||
account_id: uuid.UUID | None = None
|
||||
token_hash: str
|
||||
token_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope]
|
||||
tenants: dict[str, bool] = Field(default_factory=dict)
|
||||
external_identity: ExternalIdentity | None = None
|
||||
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
app_access_mode: WebAppAccessMode | None = None
|
||||
|
||||
caller: Account | EndUser | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
@ -1,19 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from controllers.openapi.auth.conditions import Cond
|
||||
from controllers.openapi.auth.data import AuthData, RequestContext
|
||||
|
||||
|
||||
class When:
|
||||
def __init__(self, condition: Cond, *, then: Callable[[Any], None]) -> None:
|
||||
self.condition = condition
|
||||
self._step = then
|
||||
|
||||
def applies(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
|
||||
return self.condition(ctx, data)
|
||||
|
||||
def __call__(self, arg: Any) -> None:
|
||||
self._step(arg)
|
||||
@ -1,224 +1,51 @@
|
||||
"""Auth pipeline — entry point for all openapi auth.
|
||||
"""Pipeline IS the auth scheme.
|
||||
|
||||
`PipelineRouter.guard()` is the only attachment point for endpoints.
|
||||
`AuthPipeline` is a pure step-runner with no routing concerns.
|
||||
`PipelineRoute` binds a pipeline to optional edition requirements.
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
from flask import request
|
||||
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from controllers.openapi.auth.data import (
|
||||
AuthData,
|
||||
Edition,
|
||||
RequestContext,
|
||||
current_edition,
|
||||
)
|
||||
from controllers.openapi.auth.flow import When
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
Scope,
|
||||
TokenType,
|
||||
extract_bearer,
|
||||
get_authenticator,
|
||||
reset_auth_ctx,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# New design: AuthPipeline / PipelineRoute / PipelineRouter
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
|
||||
|
||||
|
||||
class AuthPipeline:
|
||||
"""Pure step-runner — no routing, no guard.
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
|
||||
`prepare` steps receive a mutable builder dict (includes `path_params`).
|
||||
`auth` steps receive the fully constructed, frozen `AuthData`.
|
||||
"""
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
|
||||
def __init__(self, prepare: list, auth: list) -> None:
|
||||
self._prepare = prepare
|
||||
self._auth = auth
|
||||
|
||||
def _run(
|
||||
self,
|
||||
identity: AuthContext,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
) -> Any:
|
||||
req_ctx = RequestContext(
|
||||
token_type=identity.token_type,
|
||||
scope=scope,
|
||||
path_params=dict(request.view_args or {}),
|
||||
)
|
||||
|
||||
builder = _init_builder(identity, scope)
|
||||
builder["path_params"] = dict(req_ctx.path_params)
|
||||
|
||||
for step in self._prepare:
|
||||
if _should_run(step, req_ctx, data=None):
|
||||
step(builder)
|
||||
|
||||
builder.pop("path_params", None)
|
||||
builder.pop("_subject_email", None)
|
||||
builder.pop("_subject_issuer", None)
|
||||
data = AuthData(**builder)
|
||||
|
||||
for step in self._auth:
|
||||
if _should_run(step, req_ctx, data=data):
|
||||
step(data)
|
||||
|
||||
reset_token = set_auth_ctx(identity)
|
||||
if data.caller:
|
||||
_mount_flask_login(data.caller)
|
||||
|
||||
try:
|
||||
kwargs["auth_data"] = data
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
reset_auth_ctx(reset_token)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineRoute:
|
||||
pipeline: AuthPipeline
|
||||
required_edition: frozenset[Edition] | None = None
|
||||
|
||||
|
||||
class PipelineRouter:
|
||||
"""Entry point for openapi auth.
|
||||
|
||||
`guard()` is the decorator that endpoints attach to. It applies
|
||||
global gates (edition, token type) then dispatches to the matching
|
||||
`PipelineRoute` for the token type.
|
||||
"""
|
||||
|
||||
def __init__(self, routes: dict[TokenType, PipelineRoute]) -> None:
|
||||
self._routes = routes
|
||||
|
||||
def guard(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
) -> Callable:
|
||||
def decorator(view: Callable) -> Callable:
|
||||
def guard(self, *, scope: Scope):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: Any, **kwargs: Any) -> Any:
|
||||
return self._execute(
|
||||
args,
|
||||
kwargs,
|
||||
view,
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
def decorated(*args, **kwargs):
|
||||
# Extract transport-level inputs at the boundary so steps
|
||||
# stay decoupled from Flask's request object.
|
||||
ctx = Context(
|
||||
required_scope=scope,
|
||||
bearer_token=extract_bearer(request),
|
||||
path_params=dict(request.view_args or {}),
|
||||
)
|
||||
try:
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
if ctx.auth_ctx_reset_token is not None:
|
||||
reset_auth_ctx(ctx.auth_ctx_reset_token)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
) -> Any:
|
||||
# Gate 1: endpoint-level edition (404 — feature doesn't exist here)
|
||||
if edition is not None and current_edition() not in edition:
|
||||
raise NotFound()
|
||||
|
||||
# Gate 2: EE license for endpoint-level edition requirement
|
||||
if edition is not None and Edition.EE in edition:
|
||||
_check_license()
|
||||
|
||||
token = extract_bearer(request)
|
||||
if not token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
identity = get_authenticator().authenticate(token)
|
||||
|
||||
# Gate 3: endpoint-level token type allowlist (403)
|
||||
if allowed_token_types is not None and identity.token_type not in allowed_token_types:
|
||||
emit_wrong_surface(
|
||||
subject_type=_subject_type_str(identity),
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(identity, "client_id", None),
|
||||
token_id=str(identity.token_id) if identity.token_id else None,
|
||||
)
|
||||
raise Forbidden("unsupported_token_type")
|
||||
|
||||
route = self._routes.get(identity.token_type)
|
||||
if route is None:
|
||||
raise Forbidden("unsupported_token_type")
|
||||
|
||||
# Gate 4: route-level edition invariant (token type requires EE)
|
||||
if route.required_edition is not None:
|
||||
if current_edition() not in route.required_edition:
|
||||
raise Forbidden("external_sso_requires_ee")
|
||||
if Edition.EE in route.required_edition:
|
||||
_check_license()
|
||||
|
||||
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:
|
||||
if isinstance(step, When):
|
||||
return step.applies(req_ctx, data)
|
||||
return True
|
||||
|
||||
|
||||
def _init_builder(identity: AuthContext, scope: Scope | None) -> dict:
|
||||
return {
|
||||
"token_type": identity.token_type,
|
||||
"account_id": identity.account_id,
|
||||
"token_hash": identity.token_hash,
|
||||
"token_id": identity.token_id,
|
||||
"scopes": frozenset(identity.scopes),
|
||||
"tenants": dict(identity.verified_tenants),
|
||||
"required_scope": scope,
|
||||
"_subject_email": identity.subject_email,
|
||||
"_subject_issuer": identity.subject_issuer,
|
||||
}
|
||||
|
||||
|
||||
def _subject_type_str(identity: Any) -> str | None:
|
||||
subject = getattr(identity, "subject_type", None)
|
||||
if subject is None:
|
||||
return None
|
||||
return subject.value if hasattr(subject, "value") else str(subject)
|
||||
|
||||
|
||||
def _check_license() -> None:
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status in {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}:
|
||||
raise Forbidden("license_invalid")
|
||||
|
||||
|
||||
def _mount_flask_login(user: Any) -> None:
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore[attr-defined]
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type: ignore[attr-defined]
|
||||
|
||||
@ -1,78 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import ExternalIdentity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
|
||||
def build_external_identity(builder: dict) -> None:
|
||||
email = builder.pop("_subject_email", None)
|
||||
issuer = builder.pop("_subject_issuer", None)
|
||||
if email:
|
||||
builder["external_identity"] = ExternalIdentity(email=email, issuer=issuer)
|
||||
|
||||
|
||||
def load_app(builder: dict) -> None:
|
||||
app_id = builder["path_params"]["app_id"]
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
builder["app"] = app
|
||||
|
||||
|
||||
def load_tenant(builder: dict) -> None:
|
||||
app = builder["app"]
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
builder["tenant"] = tenant
|
||||
|
||||
|
||||
def load_account(builder: dict) -> None:
|
||||
account = AccountService.get_account_by_id(db.session, str(builder["account_id"]))
|
||||
if account is None:
|
||||
raise Unauthorized("account not found")
|
||||
tenant = builder.get("tenant")
|
||||
if tenant:
|
||||
account.current_tenant = tenant
|
||||
builder["caller"] = account
|
||||
builder["caller_kind"] = "account"
|
||||
|
||||
|
||||
def resolve_external_user(builder: dict) -> None:
|
||||
tenant = builder.get("tenant")
|
||||
app = builder.get("app")
|
||||
ext: ExternalIdentity | None = builder.get("external_identity")
|
||||
if not all([tenant, app, ext]):
|
||||
raise Unauthorized("missing context for external user resolution")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=str(tenant.id), # type: ignore[union-attr]
|
||||
app_id=str(app.id), # type: ignore[union-attr]
|
||||
user_id=ext.email, # type: ignore[union-attr]
|
||||
)
|
||||
builder["caller"] = end_user
|
||||
builder["caller_kind"] = "end_user"
|
||||
|
||||
|
||||
def load_app_access_mode(builder: dict) -> None:
|
||||
app = builder.get("app")
|
||||
if app is None:
|
||||
return
|
||||
try:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app.id))
|
||||
if settings is None:
|
||||
builder["app_access_mode"] = None
|
||||
return
|
||||
builder["app_access_mode"] = WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
builder["app_access_mode"] = None
|
||||
170
api/controllers/openapi/auth/steps.py
Normal file
170
api/controllers/openapi/auth/steps.py
Normal file
@ -0,0 +1,170 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`. `BearerCheck` also publishes the
|
||||
resolved identity to the openapi auth ``ContextVar`` (the same one the
|
||||
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
|
||||
surface gate and any handler reading the request-scoped context has a single
|
||||
source of truth across both auth-attach paths. The reset token is stashed
|
||||
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
|
||||
its `finally` so worker-thread reuse can't leak identity across requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from controllers.openapi.auth.surface_gate import check_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models import TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||
inside `BearerAuthenticator.authenticate`, so no separate step here.
|
||||
Also publishes the resolved `AuthContext` via
|
||||
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
|
||||
``validate_bearer`` writes — so the surface gate + downstream readers
|
||||
don't see two different identity sources. The reset token is parked on
|
||||
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not ctx.bearer_token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
try:
|
||||
authn = get_authenticator().authenticate(ctx.bearer_token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
ctx.subject_type = authn.subject_type
|
||||
ctx.subject_email = authn.subject_email
|
||||
ctx.subject_issuer = authn.subject_issuer
|
||||
ctx.account_id = authn.account_id
|
||||
ctx.scopes = frozenset(authn.scopes)
|
||||
ctx.source = authn.source
|
||||
ctx.token_id = authn.token_id
|
||||
ctx.expires_at = authn.expires_at
|
||||
ctx.token_hash = authn.token_hash
|
||||
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class SurfaceCheck:
|
||||
"""Reject the request if the resolved subject is not in `accepted`."""
|
||||
|
||||
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
|
||||
self._accepted = accepted
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
check_surface(self._accepted)
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using the OAuth bearer pipeline must declare
|
||||
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
|
||||
``ctx.path_params`` at the boundary so this step doesn't need to know
|
||||
about the request object.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = ctx.path_params.get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
|
||||
"""Layer 0 — workspace membership gate.
|
||||
|
||||
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||
(dfoa_) only — SSO subjects skip.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||
return
|
||||
if ctx.account_id is None or ctx.tenant is None:
|
||||
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||
if ctx.token_hash is None:
|
||||
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=ctx.must_tenant.id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.must_subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AppAuthzCheck",
|
||||
"AppResolver",
|
||||
"AuthContext",
|
||||
"BearerCheck",
|
||||
"CallerMount",
|
||||
"ScopeCheck",
|
||||
"SurfaceCheck",
|
||||
"WorkspaceMembershipCheck",
|
||||
]
|
||||
168
api/controllers/openapi/auth/strategies.py
Normal file
168
api/controllers/openapi/auth/strategies.py
Normal file
@ -0,0 +1,168 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import (
|
||||
EnterpriseService,
|
||||
WebAppAccessMode,
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL, evaluated in two stages.
|
||||
|
||||
The EE gateway has already enforced tenancy and workspace membership
|
||||
by the time this strategy runs, so AclStrategy only owns per-app ACL:
|
||||
|
||||
1. Subject vs access-mode compatibility (pure rule table). External-SSO
|
||||
bearers belong to public-facing apps only; account bearers cover the
|
||||
full set. A mismatch is an immediate deny — no IO.
|
||||
2. For modes that pair with the subject, decide whether the inner
|
||||
permission API must run. Only `PRIVATE` (per-app selected-user list)
|
||||
requires it; the remaining modes are pass-through.
|
||||
"""
|
||||
|
||||
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
|
||||
SubjectType.ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
SubjectType.EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
access_mode = self._fetch_access_mode(ctx.app.id)
|
||||
if access_mode is None:
|
||||
return False
|
||||
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
|
||||
return False
|
||||
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
|
||||
return True
|
||||
return self._inner_permission_check(ctx)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||
if settings is None:
|
||||
return None
|
||||
try:
|
||||
return WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
|
||||
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
|
||||
|
||||
def _inner_permission_check(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
user_id = self._resolve_user_id(ctx)
|
||||
if user_id is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_id(ctx: Context) -> str | None:
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return str(ctx.account_id) if ctx.account_id is not None else None
|
||||
if ctx.subject_email is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
|
||||
return str(account.id) if account is not None else None
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user) # type:ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.must_tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
89
api/controllers/openapi/auth/surface_gate.py
Normal file
89
api/controllers/openapi/auth/surface_gate.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""Surface gate.
|
||||
|
||||
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
|
||||
step) is the pipeline-level form. Both delegate to `check_surface` so the
|
||||
audit emit + canonical-path message are single-sourced.
|
||||
|
||||
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
|
||||
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
|
||||
``openapi.wrong_surface_denied``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from libs.oauth_bearer import SubjectType, try_get_auth_ctx
|
||||
|
||||
_CANONICAL_PATH: dict[SubjectType, str] = {
|
||||
SubjectType.ACCOUNT: "/openapi/v1/apps",
|
||||
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
|
||||
}
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def check_surface(accepted: frozenset[SubjectType]) -> None:
|
||||
"""Enforce that the resolved subject is in ``accepted``.
|
||||
|
||||
Reads the openapi auth ContextVar via :func:`try_get_auth_ctx`. Raises
|
||||
``Forbidden`` with ``wrong_surface`` + canonical-path hint on miss;
|
||||
emits ``openapi.wrong_surface_denied`` audit. If no auth context is
|
||||
set the bearer layer didn't run — that's a wiring bug, not a
|
||||
user-driven failure, so surface it as a ``RuntimeError`` instead of
|
||||
a silent 403.
|
||||
"""
|
||||
ctx = try_get_auth_ctx()
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"check_surface called without an auth context; stack validate_bearer or BearerCheck above the surface gate"
|
||||
)
|
||||
|
||||
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
|
||||
if subject in accepted:
|
||||
return
|
||||
|
||||
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
|
||||
emit_wrong_surface(
|
||||
subject_type=subject.value if subject else None,
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(ctx, "client_id", None),
|
||||
token_id=_stringify(getattr(ctx, "token_id", None)),
|
||||
)
|
||||
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
|
||||
|
||||
|
||||
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
|
||||
accepted_set: frozenset[SubjectType] = frozenset(accepted)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
check_surface(accepted_set)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
def _coerce_subject_type(raw: object) -> SubjectType | None:
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, SubjectType):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return SubjectType(raw)
|
||||
return None
|
||||
|
||||
|
||||
def _stringify(value: object) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
@ -1,82 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
|
||||
def check_scope(data: AuthData) -> None:
|
||||
if data.required_scope is None:
|
||||
return
|
||||
if Scope.FULL in data.scopes or data.required_scope in data.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
# reject_external_sso removed — PipelineRouter._execute raises Forbidden("external_sso_requires_ee")
|
||||
# directly when route.required_edition is not satisfied. Not a pipeline step.
|
||||
|
||||
|
||||
def check_membership(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
raise Unauthorized("tenant unset")
|
||||
check_workspace_membership(
|
||||
account_id=data.account_id,
|
||||
tenant_id=data.tenant.id,
|
||||
token_hash=data.token_hash,
|
||||
membership_cache=data.tenants,
|
||||
)
|
||||
|
||||
|
||||
def check_app_access(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
return
|
||||
if not TenantService.account_belongs_to_tenant(db.session, data.account_id, data.tenant.id):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
_ALLOWED_MODES_BY_TOKEN_TYPE: dict[TokenType, frozenset[WebAppAccessMode]] = {
|
||||
TokenType.OAUTH_ACCOUNT: frozenset({
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}),
|
||||
TokenType.OAUTH_EXTERNAL_SSO: frozenset({
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}),
|
||||
}
|
||||
|
||||
|
||||
def check_acl(data: AuthData) -> None:
|
||||
if data.app is None or data.app_access_mode is None:
|
||||
raise Forbidden("app or access mode not loaded")
|
||||
allowed_modes = _ALLOWED_MODES_BY_TOKEN_TYPE.get(data.token_type, frozenset())
|
||||
if data.app_access_mode not in allowed_modes:
|
||||
raise Forbidden("subject_not_allowed_for_access_mode")
|
||||
|
||||
|
||||
def check_private_app_permission(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
raise Forbidden("app not loaded")
|
||||
user_id = _resolve_user_id(data)
|
||||
if user_id is None:
|
||||
raise Forbidden("cannot resolve user for private app check")
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id, app_id=data.app.id
|
||||
):
|
||||
raise Forbidden("user_not_allowed_for_private_app")
|
||||
|
||||
|
||||
def _resolve_user_id(data: AuthData) -> str | None:
|
||||
if data.token_type == TokenType.OAUTH_ACCOUNT:
|
||||
return str(data.account_id) if data.account_id is not None else None
|
||||
if data.external_identity is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, data.external_identity.email)
|
||||
return str(account.id) if account is not None else None
|
||||
@ -17,11 +17,11 @@ from controllers.common.errors import (
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from libs.oauth_bearer import Scope
|
||||
from models import Account, App
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -39,11 +39,8 @@ class AppFileUploadApi(Resource):
|
||||
}
|
||||
)
|
||||
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model = auth_data.app
|
||||
caller = auth_data.caller
|
||||
caller_kind = auth_data.caller_kind
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
if len(request.files) > 1:
|
||||
|
||||
@ -17,8 +17,7 @@ from werkzeug.exceptions import BadRequest, NotFound
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
@ -56,11 +55,8 @@ def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
|
||||
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
@openapi_ns.response(200, "Form definition")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model = auth_data.app
|
||||
caller = auth_data.caller
|
||||
caller_kind = auth_data.caller_kind
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
@ -73,11 +69,8 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@openapi_ns.response(200, "Form submitted")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model = auth_data.app
|
||||
caller = auth_data.caller
|
||||
caller_kind = auth_data.caller_kind
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
|
||||
@ -17,8 +17,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
@ -29,7 +28,7 @@ from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
@ -37,11 +36,8 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
|
||||
class OpenApiWorkflowEventsApi(Resource):
|
||||
@openapi_ns.response(200, "SSE event stream")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model = auth_data.app
|
||||
caller = auth_data.caller
|
||||
caller_kind = auth_data.caller_kind
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
|
||||
|
||||
@ -15,10 +15,14 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import Tenant, TenantAccountJoin
|
||||
from services.account_service import TenantService
|
||||
|
||||
@ -26,9 +30,12 @@ from services.account_service import TenantService
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id))
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
|
||||
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||
|
||||
@ -36,9 +43,12 @@ class WorkspacesApi(Resource):
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self, workspace_id: str):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
@ -43,11 +43,6 @@ class SubjectType(StrEnum):
|
||||
EXTERNAL_SSO = "external_sso"
|
||||
|
||||
|
||||
class TokenType(StrEnum):
|
||||
OAUTH_ACCOUNT = "oauth_account"
|
||||
OAUTH_EXTERNAL_SSO = "oauth_external_sso"
|
||||
|
||||
|
||||
class Scope(StrEnum):
|
||||
"""Catalog of bearer scopes recognised by the openapi surface.
|
||||
|
||||
@ -60,7 +55,6 @@ class Scope(StrEnum):
|
||||
APPS_READ = "apps:read"
|
||||
APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
|
||||
APPS_RUN = "apps:run"
|
||||
WORKSPACE_READ = "workspace:read"
|
||||
|
||||
|
||||
class Accepts(StrEnum):
|
||||
@ -83,7 +77,7 @@ _SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
|
||||
class AuthContext:
|
||||
"""Per-request identity published via :data:`_auth_ctx_var`
|
||||
(see :func:`set_auth_ctx` / :func:`get_auth_ctx`). ``scopes`` /
|
||||
``subject_type`` / ``token_type`` come from the TokenKind, not the DB —
|
||||
``subject_type`` / ``source`` come from the TokenKind, not the DB —
|
||||
corrupt rows can't elevate scope.
|
||||
|
||||
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
|
||||
@ -98,7 +92,7 @@ class AuthContext:
|
||||
client_id: str | None
|
||||
scopes: frozenset[Scope]
|
||||
token_id: uuid.UUID
|
||||
token_type: TokenType
|
||||
source: str
|
||||
expires_at: datetime | None
|
||||
token_hash: str
|
||||
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
||||
@ -186,7 +180,7 @@ class TokenKind:
|
||||
prefix: str
|
||||
subject_type: SubjectType
|
||||
scopes: frozenset[Scope]
|
||||
token_type: TokenType
|
||||
source: str
|
||||
resolver: Resolver
|
||||
|
||||
def matches(self, token: str) -> bool:
|
||||
@ -297,7 +291,7 @@ class BearerAuthenticator:
|
||||
client_id=row.client_id,
|
||||
scopes=kind.scopes,
|
||||
token_id=row.token_id,
|
||||
token_type=kind.token_type,
|
||||
source=kind.source,
|
||||
expires_at=row.expires_at,
|
||||
token_hash=token_hash,
|
||||
verified_tenants=dict(row.verified_tenants),
|
||||
@ -489,7 +483,7 @@ def check_workspace_membership(
|
||||
account_id: uuid.UUID | str,
|
||||
tenant_id: str,
|
||||
token_hash: str,
|
||||
membership_cache: dict[str, bool],
|
||||
cached_verdicts: dict[str, bool],
|
||||
) -> None:
|
||||
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
|
||||
|
||||
@ -498,7 +492,7 @@ def check_workspace_membership(
|
||||
short-circuiting on EE / SSO subjects before invoking — this function
|
||||
runs the membership + active-status checks unconditionally.
|
||||
"""
|
||||
cached = membership_cache.get(tenant_id)
|
||||
cached = cached_verdicts.get(tenant_id)
|
||||
if cached is True:
|
||||
return
|
||||
if cached is False:
|
||||
@ -536,7 +530,7 @@ def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=tenant_id,
|
||||
token_hash=ctx.token_hash,
|
||||
membership_cache=ctx.verified_tenants,
|
||||
cached_verdicts=ctx.verified_tenants,
|
||||
)
|
||||
|
||||
|
||||
@ -670,14 +664,14 @@ def build_registry(session_factory, redis_client) -> TokenKindRegistry:
|
||||
prefix=account.prefix,
|
||||
subject_type=account.subject_type,
|
||||
scopes=account.scopes,
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
source="oauth_account",
|
||||
resolver=oauth.for_account(),
|
||||
),
|
||||
TokenKind(
|
||||
prefix=external.prefix,
|
||||
subject_type=external.subject_type,
|
||||
scopes=external.scopes,
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
source="oauth_external_sso",
|
||||
resolver=oauth.for_external_sso(),
|
||||
),
|
||||
]
|
||||
|
||||
@ -1,80 +1,66 @@
|
||||
from controllers.openapi.auth.composition import account_pipeline, auth_router, external_sso_pipeline
|
||||
from controllers.openapi.auth.flow import When
|
||||
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
||||
from libs.oauth_bearer import TokenType
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def test_account_pipeline_is_auth_pipeline():
|
||||
assert isinstance(account_pipeline, AuthPipeline)
|
||||
def test_pipeline_is_composed():
|
||||
assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline)
|
||||
|
||||
|
||||
def test_external_sso_pipeline_is_auth_pipeline():
|
||||
assert isinstance(external_sso_pipeline, AuthPipeline)
|
||||
def test_pipeline_step_order():
|
||||
"""BearerCheck → SurfaceCheck → ScopeCheck → AppResolver →
|
||||
WorkspaceMembershipCheck → AppAuthzCheck → CallerMount.
|
||||
SurfaceCheck enforces the dfoa_/dfoe_ surface split + emits
|
||||
`openapi.wrong_surface_denied`. Rate-limit is enforced inside
|
||||
`BearerAuthenticator.authenticate`, not as a separate pipeline step."""
|
||||
steps = OAUTH_BEARER_PIPELINE._steps
|
||||
assert isinstance(steps[0], BearerCheck)
|
||||
assert isinstance(steps[1], SurfaceCheck)
|
||||
assert isinstance(steps[2], ScopeCheck)
|
||||
assert isinstance(steps[3], AppResolver)
|
||||
assert isinstance(steps[4], WorkspaceMembershipCheck)
|
||||
assert isinstance(steps[5], AppAuthzCheck)
|
||||
assert isinstance(steps[6], CallerMount)
|
||||
|
||||
|
||||
def test_auth_router_is_pipeline_router():
|
||||
assert isinstance(auth_router, PipelineRouter)
|
||||
def test_pipeline_surface_check_accepts_account_only():
|
||||
"""Current pipeline serves /apps/<id>/run — account surface only."""
|
||||
surface = OAUTH_BEARER_PIPELINE._steps[1]
|
||||
assert isinstance(surface, SurfaceCheck)
|
||||
assert surface._accepted == frozenset({SubjectType.ACCOUNT})
|
||||
|
||||
|
||||
def test_account_pipeline_prepare_has_four_entries():
|
||||
assert len(account_pipeline._prepare) == 4
|
||||
def test_caller_mount_has_both_mounters():
|
||||
cm = OAUTH_BEARER_PIPELINE._steps[6]
|
||||
kinds = {type(m) for m in cm._mounters}
|
||||
assert AccountMounter in kinds
|
||||
assert EndUserMounter in kinds
|
||||
|
||||
|
||||
def test_account_auth_list_has_five_entries():
|
||||
assert len(account_pipeline._auth) == 5
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
|
||||
def test_strategy_resolver_picks_acl_when_enabled(fs):
|
||||
fs.get_system_features.return_value.webapp_auth.enabled = True
|
||||
assert isinstance(_resolve_app_authz_strategy(), AclStrategy)
|
||||
|
||||
|
||||
def test_external_sso_pipeline_prepare_has_five_entries():
|
||||
assert len(external_sso_pipeline._prepare) == 5
|
||||
|
||||
|
||||
def test_external_sso_auth_list_has_three_entries():
|
||||
# check_scope (unconditional) + 2 When entries
|
||||
assert len(external_sso_pipeline._auth) == 3
|
||||
|
||||
|
||||
def test_account_pipeline_has_unconditional_load_account():
|
||||
# load_account is the only bare (non-When) entry in account prepare
|
||||
non_when = [s for s in account_pipeline._prepare if not isinstance(s, When)]
|
||||
assert len(non_when) == 1
|
||||
|
||||
|
||||
def test_external_sso_pipeline_first_prepare_is_build_external_identity():
|
||||
from controllers.openapi.auth.prepare import build_external_identity
|
||||
assert external_sso_pipeline._prepare[0] is build_external_identity
|
||||
|
||||
|
||||
def test_external_sso_pipeline_remaining_prepare_entries_are_when():
|
||||
assert all(isinstance(s, When) for s in external_sso_pipeline._prepare[1:])
|
||||
|
||||
|
||||
def test_first_auth_entry_is_check_scope_in_both_pipelines():
|
||||
# check_scope is unconditional (not a When) and comes first in auth
|
||||
assert not isinstance(account_pipeline._auth[0], When)
|
||||
assert not isinstance(external_sso_pipeline._auth[0], When)
|
||||
|
||||
|
||||
def test_remaining_auth_entries_are_when_for_account():
|
||||
assert all(isinstance(s, When) for s in account_pipeline._auth[1:])
|
||||
|
||||
|
||||
def test_remaining_auth_entries_are_when_for_external_sso():
|
||||
assert all(isinstance(s, When) for s in external_sso_pipeline._auth[1:])
|
||||
|
||||
|
||||
def test_router_routes_contain_both_token_types():
|
||||
assert TokenType.OAUTH_ACCOUNT in auth_router._routes
|
||||
assert TokenType.OAUTH_EXTERNAL_SSO in auth_router._routes
|
||||
|
||||
|
||||
def test_external_sso_route_has_ee_required_edition():
|
||||
route = auth_router._routes[TokenType.OAUTH_EXTERNAL_SSO]
|
||||
assert isinstance(route, PipelineRoute)
|
||||
from controllers.openapi.auth.data import Edition
|
||||
assert route.required_edition == frozenset({Edition.EE})
|
||||
|
||||
|
||||
def test_account_route_has_no_required_edition():
|
||||
route = auth_router._routes[TokenType.OAUTH_ACCOUNT]
|
||||
assert isinstance(route, PipelineRoute)
|
||||
assert route.required_edition is None
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
|
||||
def test_strategy_resolver_picks_membership_when_disabled(fs):
|
||||
fs.get_system_features.return_value.webapp_auth.enabled = False
|
||||
assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy)
|
||||
|
||||
@ -1,149 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_CE,
|
||||
EDITION_EE,
|
||||
EDITION_SAAS,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
TOKEN_IS_OAUTH_ACCOUNT,
|
||||
TOKEN_IS_OAUTH_EXTERNAL_SSO,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
Cond,
|
||||
config_cond,
|
||||
data_cond,
|
||||
request_cond,
|
||||
)
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext
|
||||
from libs.oauth_bearer import TokenType
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
def _ctx(token_type=TokenType.OAUTH_ACCOUNT, path_params=None):
|
||||
return RequestContext(
|
||||
token_type=token_type,
|
||||
path_params=path_params or {},
|
||||
)
|
||||
|
||||
|
||||
def _data(**kwargs):
|
||||
defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "x", "scopes": frozenset()}
|
||||
defaults.update(kwargs)
|
||||
return AuthData(**defaults)
|
||||
|
||||
|
||||
# --- Cond operators ---
|
||||
|
||||
def test_and_both_true():
|
||||
a = Cond(lambda ctx, _: True)
|
||||
b = Cond(lambda ctx, _: True)
|
||||
assert (a & b)(_ctx()) is True
|
||||
|
||||
|
||||
def test_and_one_false():
|
||||
a = Cond(lambda ctx, _: True)
|
||||
b = Cond(lambda ctx, _: False)
|
||||
assert (a & b)(_ctx()) is False
|
||||
|
||||
|
||||
def test_or_one_true():
|
||||
a = Cond(lambda ctx, _: False)
|
||||
b = Cond(lambda ctx, _: True)
|
||||
assert (a | b)(_ctx()) is True
|
||||
|
||||
|
||||
def test_or_both_false():
|
||||
a = Cond(lambda ctx, _: False)
|
||||
b = Cond(lambda ctx, _: False)
|
||||
assert (a | b)(_ctx()) is False
|
||||
|
||||
|
||||
def test_invert():
|
||||
a = Cond(lambda ctx, _: True)
|
||||
assert (~a)(_ctx()) is False
|
||||
|
||||
|
||||
def test_chain_and_or():
|
||||
always_true = Cond(lambda ctx, _: True)
|
||||
always_false = Cond(lambda ctx, _: False)
|
||||
assert ((always_true | always_false) & always_true)(_ctx()) is True
|
||||
|
||||
|
||||
# --- Helper constructors ---
|
||||
|
||||
def test_request_cond_ignores_data():
|
||||
c = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
|
||||
assert c(_ctx(TokenType.OAUTH_ACCOUNT)) is True
|
||||
assert c(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False
|
||||
|
||||
|
||||
def test_data_cond_returns_false_when_data_none():
|
||||
c = data_cond(lambda data: True)
|
||||
assert c(_ctx(), None) is False
|
||||
|
||||
|
||||
def test_data_cond_evaluates_when_data_present():
|
||||
c = data_cond(lambda data: data.token_hash == "secret")
|
||||
assert c(_ctx(), _data(token_hash="secret")) is True
|
||||
assert c(_ctx(), _data(token_hash="other")) is False
|
||||
|
||||
|
||||
def test_config_cond_ignores_ctx_and_data():
|
||||
c = config_cond(lambda: True)
|
||||
assert c(_ctx()) is True
|
||||
c2 = config_cond(lambda: False)
|
||||
assert c2(_ctx(), _data()) is False
|
||||
|
||||
|
||||
# --- Pre-built conditions ---
|
||||
|
||||
def test_token_is_oauth_account():
|
||||
assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_ACCOUNT)) is True
|
||||
assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False
|
||||
|
||||
|
||||
def test_token_is_oauth_external_sso():
|
||||
assert TOKEN_IS_OAUTH_EXTERNAL_SSO(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is True
|
||||
|
||||
|
||||
def test_path_has_app_id_true():
|
||||
assert PATH_HAS_APP_ID(_ctx(path_params={"app_id": "abc"})) is True
|
||||
|
||||
|
||||
def test_path_has_app_id_false():
|
||||
assert PATH_HAS_APP_ID(_ctx(path_params={})) is False
|
||||
|
||||
|
||||
def test_edition_ce():
|
||||
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.CE):
|
||||
assert EDITION_CE(_ctx()) is True
|
||||
assert EDITION_EE(_ctx()) is False
|
||||
assert EDITION_SAAS(_ctx()) is False
|
||||
|
||||
|
||||
def test_edition_ee():
|
||||
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.EE):
|
||||
assert EDITION_EE(_ctx()) is True
|
||||
assert EDITION_CE(_ctx()) is False
|
||||
|
||||
|
||||
def test_edition_saas():
|
||||
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.SAAS):
|
||||
assert EDITION_SAAS(_ctx()) is True
|
||||
|
||||
|
||||
def test_webapp_auth_enabled():
|
||||
mock_features = MagicMock()
|
||||
mock_features.webapp_auth.enabled = True
|
||||
with patch("controllers.openapi.auth.conditions.FeatureService.get_system_features", return_value=mock_features):
|
||||
assert WEBAPP_AUTH_ENABLED(_ctx()) is True
|
||||
|
||||
|
||||
def test_loaded_app_is_private():
|
||||
data_private = _data(app_access_mode=WebAppAccessMode.PRIVATE)
|
||||
data_public = _data(app_access_mode=WebAppAccessMode.PUBLIC)
|
||||
data_none = _data(app_access_mode=None)
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), data_private) is True
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), data_public) is False
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), data_none) is False
|
||||
assert LOADED_APP_IS_PRIVATE(_ctx(), None) is False
|
||||
@ -0,0 +1,21 @@
|
||||
from controllers.openapi.auth.context import Context
|
||||
|
||||
|
||||
def test_context_starts_unpopulated():
|
||||
ctx = Context(required_scope="apps:run")
|
||||
assert ctx.bearer_token is None
|
||||
assert ctx.path_params == {}
|
||||
assert ctx.subject_type is None
|
||||
assert ctx.subject_email is None
|
||||
assert ctx.account_id is None
|
||||
assert ctx.scopes == frozenset()
|
||||
assert ctx.app is None
|
||||
assert ctx.tenant is None
|
||||
assert ctx.caller is None
|
||||
assert ctx.caller_kind is None
|
||||
|
||||
|
||||
def test_context_fields_are_mutable():
|
||||
ctx = Context(required_scope="apps:run")
|
||||
ctx.scopes = frozenset({"full"})
|
||||
assert "full" in ctx.scopes
|
||||
@ -1,116 +0,0 @@
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.openapi.auth.data import (
|
||||
AuthData,
|
||||
Edition,
|
||||
ExternalIdentity,
|
||||
RequestContext,
|
||||
current_edition,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
|
||||
# --- Edition / current_edition ---
|
||||
|
||||
|
||||
def test_current_edition_saas():
|
||||
with patch("controllers.openapi.auth.data.dify_config") as cfg:
|
||||
cfg.EDITION = "CLOUD"
|
||||
cfg.ENTERPRISE_ENABLED = True
|
||||
assert current_edition() == Edition.SAAS
|
||||
|
||||
|
||||
def test_current_edition_ee():
|
||||
with patch("controllers.openapi.auth.data.dify_config") as cfg:
|
||||
cfg.EDITION = "SELF_HOSTED"
|
||||
cfg.ENTERPRISE_ENABLED = True
|
||||
assert current_edition() == Edition.EE
|
||||
|
||||
|
||||
def test_current_edition_ce():
|
||||
with patch("controllers.openapi.auth.data.dify_config") as cfg:
|
||||
cfg.EDITION = "SELF_HOSTED"
|
||||
cfg.ENTERPRISE_ENABLED = False
|
||||
assert current_edition() == Edition.CE
|
||||
|
||||
|
||||
# --- ExternalIdentity ---
|
||||
|
||||
def test_external_identity_frozen():
|
||||
ei = ExternalIdentity(email="a@b.com", issuer="idp")
|
||||
with pytest.raises(ValidationError):
|
||||
ei.email = "other@b.com" # type: ignore[misc]
|
||||
|
||||
|
||||
def test_external_identity_issuer_optional():
|
||||
ei = ExternalIdentity(email="a@b.com")
|
||||
assert ei.issuer is None
|
||||
|
||||
|
||||
# --- RequestContext ---
|
||||
|
||||
def test_request_context_frozen():
|
||||
ctx = RequestContext(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
path_params={"app_id": "123"},
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
ctx.token_type = TokenType.OAUTH_EXTERNAL_SSO # type: ignore[misc]
|
||||
|
||||
|
||||
def test_request_context_scope_optional():
|
||||
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
|
||||
assert ctx.scope is None
|
||||
|
||||
|
||||
# --- AuthData ---
|
||||
|
||||
def test_auth_data_frozen():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
data.token_type = TokenType.OAUTH_EXTERNAL_SSO # type: ignore[misc]
|
||||
|
||||
|
||||
def test_auth_data_account_id_optional():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
token_hash="abc",
|
||||
scopes=frozenset({Scope.APPS_RUN}),
|
||||
external_identity=ExternalIdentity(email="u@sso.com"),
|
||||
)
|
||||
assert data.account_id is None
|
||||
|
||||
|
||||
def test_auth_data_external_identity_none_for_account():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="abc",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
)
|
||||
assert data.external_identity is None
|
||||
|
||||
|
||||
def test_auth_data_tenants_default_empty():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset(),
|
||||
)
|
||||
assert data.tenants == {}
|
||||
|
||||
|
||||
def test_auth_data_token_id_optional():
|
||||
data = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
token_hash="abc",
|
||||
scopes=frozenset(),
|
||||
)
|
||||
assert data.token_id is None
|
||||
@ -1,42 +0,0 @@
|
||||
import inspect
|
||||
|
||||
from controllers.openapi.auth.conditions import Cond
|
||||
from controllers.openapi.auth.data import AuthData, RequestContext
|
||||
from controllers.openapi.auth.flow import When
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
|
||||
def _ctx():
|
||||
return RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
|
||||
|
||||
|
||||
def _data():
|
||||
return AuthData(token_type=TokenType.OAUTH_ACCOUNT, token_hash="x", scopes=frozenset())
|
||||
|
||||
|
||||
def test_applies_returns_true_when_condition_true():
|
||||
w = When(Cond(lambda ctx, _: True), then=lambda b: None)
|
||||
assert w.applies(_ctx()) is True
|
||||
|
||||
|
||||
def test_applies_returns_false_when_condition_false():
|
||||
w = When(Cond(lambda ctx, _: False), then=lambda b: None)
|
||||
assert w.applies(_ctx()) is False
|
||||
|
||||
|
||||
def test_applies_with_data():
|
||||
w = When(Cond(lambda ctx, data: data is not None), then=lambda b: None)
|
||||
assert w.applies(_ctx(), _data()) is True
|
||||
assert w.applies(_ctx(), None) is False
|
||||
|
||||
|
||||
def test_call_invokes_step():
|
||||
calls = []
|
||||
w = When(Cond(lambda ctx, _: True), then=lambda arg: calls.append(arg))
|
||||
w("payload")
|
||||
assert calls == ["payload"]
|
||||
|
||||
|
||||
def test_then_is_keyword_only():
|
||||
sig = inspect.signature(When.__init__)
|
||||
assert sig.parameters["then"].kind.name == "KEYWORD_ONLY"
|
||||
@ -1,204 +1,59 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, Edition
|
||||
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
|
||||
|
||||
def _make_identity(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=None,
|
||||
scopes=None,
|
||||
token_hash="testhash",
|
||||
subject_email=None,
|
||||
subject_issuer=None,
|
||||
verified_tenants=None,
|
||||
token_id=None,
|
||||
):
|
||||
identity = MagicMock()
|
||||
identity.token_type = token_type
|
||||
identity.account_id = account_id or uuid.uuid4()
|
||||
identity.scopes = scopes or frozenset({Scope.FULL})
|
||||
identity.token_hash = token_hash
|
||||
identity.subject_email = subject_email
|
||||
identity.subject_issuer = subject_issuer
|
||||
identity.verified_tenants = verified_tenants or {}
|
||||
identity.token_id = token_id or uuid.uuid4()
|
||||
return identity
|
||||
def test_run_invokes_each_step_in_order():
|
||||
calls = []
|
||||
|
||||
class S:
|
||||
def __init__(self, tag):
|
||||
self.tag = tag
|
||||
|
||||
def __call__(self, ctx):
|
||||
calls.append(self.tag)
|
||||
|
||||
Pipeline(S("a"), S("b"), S("c")).run(Context(required_scope="x"))
|
||||
assert calls == ["a", "b", "c"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
return Flask(__name__)
|
||||
def test_run_short_circuits_on_raise():
|
||||
calls = []
|
||||
|
||||
class Boom:
|
||||
def __call__(self, ctx):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
class Tail:
|
||||
def __call__(self, ctx):
|
||||
calls.append("ran")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
Pipeline(Boom(), Tail()).run(Context(required_scope="x"))
|
||||
assert calls == []
|
||||
|
||||
|
||||
def _make_router(token_type=TokenType.OAUTH_ACCOUNT, prepare=None, auth=None):
|
||||
pipeline = AuthPipeline(prepare=prepare or [], auth=auth or [])
|
||||
return PipelineRouter({token_type: PipelineRoute(pipeline)})
|
||||
def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs():
|
||||
seen = {}
|
||||
|
||||
class FakeStep:
|
||||
def __call__(self, ctx):
|
||||
ctx.app = "APP"
|
||||
ctx.caller = "CALLER"
|
||||
ctx.caller_kind = "account"
|
||||
|
||||
def _fake_identity():
|
||||
return _make_identity()
|
||||
pipeline = Pipeline(FakeStep())
|
||||
|
||||
@pipeline.guard(scope="apps:run")
|
||||
def handler(app_model, caller, caller_kind):
|
||||
seen["app_model"] = app_model
|
||||
seen["caller"] = caller
|
||||
seen["caller_kind"] = caller_kind
|
||||
return "ok"
|
||||
|
||||
# --- PipelineRouter.guard ---
|
||||
|
||||
def test_guard_passes_auth_data_to_view(app):
|
||||
router = _make_router()
|
||||
received = {}
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
|
||||
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), \
|
||||
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"):
|
||||
mock_auth.return_value.authenticate.return_value = _fake_identity()
|
||||
|
||||
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def view(*, auth_data):
|
||||
received["data"] = auth_data
|
||||
|
||||
view()
|
||||
|
||||
assert isinstance(received["data"], AuthData)
|
||||
|
||||
|
||||
def test_guard_edition_gate_returns_404(app):
|
||||
router = _make_router()
|
||||
|
||||
with app.test_request_context("/test"):
|
||||
with patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE):
|
||||
|
||||
@router.guard(scope=Scope.FULL, edition=frozenset({Edition.EE}))
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
view()
|
||||
|
||||
|
||||
def test_guard_token_type_gate_returns_403(app):
|
||||
router = _make_router()
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
|
||||
patch("controllers.openapi.auth.pipeline.emit_wrong_surface"), \
|
||||
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE):
|
||||
identity = _fake_identity()
|
||||
identity.token_type = TokenType.OAUTH_EXTERNAL_SSO
|
||||
mock_auth.return_value.authenticate.return_value = identity
|
||||
|
||||
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
def test_guard_unregistered_token_type_returns_403(app):
|
||||
# Router has only OAUTH_ACCOUNT; present OAUTH_EXTERNAL_SSO without allowed_token_types gate
|
||||
router = _make_router(token_type=TokenType.OAUTH_ACCOUNT)
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
|
||||
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE):
|
||||
identity = _fake_identity()
|
||||
identity.token_type = TokenType.OAUTH_EXTERNAL_SSO
|
||||
mock_auth.return_value.authenticate.return_value = identity
|
||||
|
||||
@router.guard(scope=Scope.FULL)
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
def test_guard_no_bearer_returns_401(app):
|
||||
router = _make_router()
|
||||
|
||||
with app.test_request_context("/test"):
|
||||
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value=None):
|
||||
|
||||
@router.guard(scope=Scope.FULL)
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
with pytest.raises(Unauthorized):
|
||||
view()
|
||||
|
||||
|
||||
def test_guard_runs_prepare_steps_in_order(app):
|
||||
order = []
|
||||
|
||||
def p1(b):
|
||||
order.append("p1")
|
||||
|
||||
def p2(b):
|
||||
order.append("p2")
|
||||
|
||||
router = _make_router(prepare=[p1, p2])
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
|
||||
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), \
|
||||
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"):
|
||||
mock_auth.return_value.authenticate.return_value = _fake_identity()
|
||||
|
||||
@router.guard(scope=Scope.FULL)
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
view()
|
||||
|
||||
assert order == ["p1", "p2"]
|
||||
|
||||
|
||||
def test_guard_resets_auth_ctx_on_exception(app):
|
||||
router = _make_router()
|
||||
reset_called = []
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
|
||||
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value="tok"), \
|
||||
patch("controllers.openapi.auth.pipeline.reset_auth_ctx", side_effect=lambda t: reset_called.append(t)):
|
||||
mock_auth.return_value.authenticate.return_value = _fake_identity()
|
||||
|
||||
@router.guard(scope=Scope.FULL)
|
||||
def view(*, auth_data):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
view()
|
||||
|
||||
assert reset_called == ["tok"]
|
||||
|
||||
|
||||
def test_router_rejects_token_type_on_wrong_edition(app):
|
||||
pipeline = AuthPipeline(prepare=[], auth=[])
|
||||
route = PipelineRoute(pipeline, required_edition=frozenset({Edition.EE}))
|
||||
router = PipelineRouter({TokenType.OAUTH_EXTERNAL_SSO: route})
|
||||
|
||||
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
|
||||
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
|
||||
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
|
||||
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE):
|
||||
identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO)
|
||||
mock_auth.return_value.authenticate.return_value = identity
|
||||
|
||||
@router.guard(scope=Scope.APPS_RUN)
|
||||
def view(*, auth_data):
|
||||
pass
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/x", method="POST"):
|
||||
assert handler() == "ok"
|
||||
assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"}
|
||||
|
||||
@ -1,187 +0,0 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import ExternalIdentity
|
||||
from controllers.openapi.auth.prepare import (
|
||||
build_external_identity,
|
||||
load_account,
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
resolve_external_user,
|
||||
)
|
||||
|
||||
# --- load_app ---
|
||||
|
||||
|
||||
def test_load_app_writes_app_to_builder():
|
||||
app = MagicMock()
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
builder = {"path_params": {"app_id": "abc"}}
|
||||
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
|
||||
load_app(builder)
|
||||
assert builder["app"] is app
|
||||
|
||||
|
||||
def test_load_app_raises_not_found_when_missing():
|
||||
builder = {"path_params": {"app_id": "missing"}}
|
||||
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=None):
|
||||
with pytest.raises(NotFound):
|
||||
load_app(builder)
|
||||
|
||||
|
||||
def test_load_app_raises_not_found_when_not_normal():
|
||||
app = MagicMock()
|
||||
app.status = "archived"
|
||||
builder = {"path_params": {"app_id": "abc"}}
|
||||
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
|
||||
with pytest.raises(NotFound):
|
||||
load_app(builder)
|
||||
|
||||
|
||||
def test_load_app_raises_forbidden_when_api_disabled():
|
||||
app = MagicMock()
|
||||
app.status = "normal"
|
||||
app.enable_api = False
|
||||
builder = {"path_params": {"app_id": "abc"}}
|
||||
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
|
||||
with pytest.raises(Forbidden):
|
||||
load_app(builder)
|
||||
|
||||
|
||||
# --- load_tenant ---
|
||||
|
||||
def test_load_tenant_writes_tenant():
|
||||
app = MagicMock()
|
||||
app.tenant_id = uuid.uuid4()
|
||||
tenant = MagicMock()
|
||||
tenant.status = "normal"
|
||||
builder = {"app": app}
|
||||
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
|
||||
load_tenant(builder)
|
||||
assert builder["tenant"] is tenant
|
||||
|
||||
|
||||
def test_load_tenant_raises_forbidden_when_archived():
|
||||
from models.account import TenantStatus
|
||||
app = MagicMock()
|
||||
app.tenant_id = uuid.uuid4()
|
||||
tenant = MagicMock()
|
||||
tenant.status = TenantStatus.ARCHIVE
|
||||
builder = {"app": app}
|
||||
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
|
||||
with pytest.raises(Forbidden):
|
||||
load_tenant(builder)
|
||||
|
||||
|
||||
def test_load_tenant_raises_forbidden_when_missing():
|
||||
app = MagicMock()
|
||||
app.tenant_id = uuid.uuid4()
|
||||
builder = {"app": app}
|
||||
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=None):
|
||||
with pytest.raises(Forbidden):
|
||||
load_tenant(builder)
|
||||
|
||||
|
||||
# --- load_account ---
|
||||
|
||||
def test_load_account_writes_caller():
|
||||
account = MagicMock()
|
||||
account_id = uuid.uuid4()
|
||||
builder = {"account_id": account_id}
|
||||
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account):
|
||||
load_account(builder)
|
||||
assert builder["caller"] is account
|
||||
assert builder["caller_kind"] == "account"
|
||||
|
||||
|
||||
def test_load_account_sets_current_tenant_when_tenant_present():
|
||||
account = MagicMock()
|
||||
tenant = MagicMock()
|
||||
builder = {"account_id": uuid.uuid4(), "tenant": tenant}
|
||||
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account):
|
||||
load_account(builder)
|
||||
assert account.current_tenant is tenant
|
||||
|
||||
|
||||
def test_load_account_raises_unauthorized_when_not_found():
|
||||
builder = {"account_id": uuid.uuid4()}
|
||||
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=None):
|
||||
with pytest.raises(Unauthorized):
|
||||
load_account(builder)
|
||||
|
||||
|
||||
# --- resolve_external_user ---
|
||||
|
||||
def test_resolve_external_user_writes_caller():
|
||||
tenant = MagicMock()
|
||||
app = MagicMock()
|
||||
end_user = MagicMock()
|
||||
ext = ExternalIdentity(email="user@sso.com")
|
||||
builder = {"tenant": tenant, "app": app, "external_identity": ext}
|
||||
with patch("controllers.openapi.auth.prepare.EndUserService.get_or_create_end_user_by_type", return_value=end_user):
|
||||
resolve_external_user(builder)
|
||||
assert builder["caller"] is end_user
|
||||
assert builder["caller_kind"] == "end_user"
|
||||
|
||||
|
||||
def test_resolve_external_user_raises_unauthorized_when_context_missing():
|
||||
builder = {"tenant": None, "app": MagicMock(), "external_identity": ExternalIdentity(email="u@s.com")}
|
||||
with pytest.raises(Unauthorized):
|
||||
resolve_external_user(builder)
|
||||
|
||||
|
||||
# --- load_app_access_mode ---
|
||||
|
||||
def test_load_app_access_mode_writes_mode():
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
app = MagicMock()
|
||||
app.id = "app-1"
|
||||
settings = MagicMock()
|
||||
settings.access_mode = "public"
|
||||
builder = {"app": app}
|
||||
with patch(
|
||||
"controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
|
||||
return_value=settings,
|
||||
):
|
||||
load_app_access_mode(builder)
|
||||
assert builder["app_access_mode"] == WebAppAccessMode.PUBLIC
|
||||
|
||||
|
||||
def test_load_app_access_mode_writes_none_when_value_error():
|
||||
app = MagicMock()
|
||||
app.id = "app-1"
|
||||
builder = {"app": app}
|
||||
with patch(
|
||||
"controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
|
||||
side_effect=ValueError("No data found."),
|
||||
):
|
||||
load_app_access_mode(builder)
|
||||
assert builder["app_access_mode"] is None
|
||||
|
||||
|
||||
def test_load_app_access_mode_no_op_when_app_missing():
|
||||
builder = {}
|
||||
load_app_access_mode(builder)
|
||||
assert "app_access_mode" not in builder
|
||||
|
||||
|
||||
# --- build_external_identity ---
|
||||
|
||||
def test_build_external_identity_constructs_from_builder_keys():
|
||||
from controllers.openapi.auth.data import ExternalIdentity
|
||||
builder = {"_subject_email": "u@sso.com", "_subject_issuer": "idp"}
|
||||
build_external_identity(builder)
|
||||
assert isinstance(builder["external_identity"], ExternalIdentity)
|
||||
assert builder["external_identity"].email == "u@sso.com"
|
||||
assert "_subject_email" not in builder
|
||||
|
||||
|
||||
def test_build_external_identity_no_op_when_email_missing():
|
||||
builder = {"_subject_email": None, "_subject_issuer": None}
|
||||
build_external_identity(builder)
|
||||
assert "external_identity" not in builder
|
||||
@ -0,0 +1,64 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import AppResolver
|
||||
from models import TenantStatus
|
||||
|
||||
|
||||
def _ctx(path_params: dict[str, str] | None) -> Context:
|
||||
return Context(required_scope="apps:run", path_params=path_params or {})
|
||||
|
||||
|
||||
def _app(*, status="normal", enable_api=True):
|
||||
return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api)
|
||||
|
||||
|
||||
def _tenant(*, status=TenantStatus.NORMAL):
|
||||
return SimpleNamespace(id="t1", status=status)
|
||||
|
||||
|
||||
def test_resolver_rejects_missing_path_param():
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx({}))
|
||||
|
||||
|
||||
def test_resolver_rejects_empty_path_params():
|
||||
# `Pipeline.guard` always seeds an empty dict when Flask reports no
|
||||
# view args, so a missing `app_id` key surfaces here as BadRequest.
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx(None))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_404_when_app_missing(db):
|
||||
db.session.get.side_effect = [None]
|
||||
with pytest.raises(NotFound):
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_403_when_disabled(db):
|
||||
db.session.get.side_effect = [_app(enable_api=False)]
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
assert "service_api_disabled" in str(exc.value.description)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_403_when_tenant_archived(db):
|
||||
db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)]
|
||||
with pytest.raises(Forbidden):
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_populates_app_and_tenant(db):
|
||||
db.session.get.side_effect = [_app(), _tenant()]
|
||||
ctx = _ctx({"app_id": "x"})
|
||||
AppResolver()(ctx)
|
||||
assert ctx.app.id == "app1"
|
||||
assert ctx.tenant.id == "t1"
|
||||
@ -0,0 +1,76 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import AppAuthzCheck
|
||||
from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id="acc1"):
|
||||
c = Context(required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.subject_email = "alice@example.com"
|
||||
c.account_id = account_id
|
||||
c.app = SimpleNamespace(id="app1")
|
||||
c.tenant = SimpleNamespace(id="t1")
|
||||
return c
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies.EnterpriseService")
|
||||
def test_acl_strategy_private_calls_inner_api(ent):
|
||||
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode="private")
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True
|
||||
assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with(
|
||||
user_id="acc1",
|
||||
app_id="app1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("access_mode", "subject_type", "expected"),
|
||||
[
|
||||
("public", SubjectType.ACCOUNT, True),
|
||||
("public", SubjectType.EXTERNAL_SSO, True),
|
||||
("sso_verified", SubjectType.ACCOUNT, True),
|
||||
("sso_verified", SubjectType.EXTERNAL_SSO, True),
|
||||
("private_all", SubjectType.ACCOUNT, True),
|
||||
("private_all", SubjectType.EXTERNAL_SSO, False),
|
||||
("private", SubjectType.EXTERNAL_SSO, False),
|
||||
],
|
||||
)
|
||||
@patch("controllers.openapi.auth.strategies.EnterpriseService")
|
||||
def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expected):
|
||||
"""Step 1 matrix: subject vs access-mode compatibility. No inner API call expected."""
|
||||
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode=access_mode)
|
||||
account_id = "acc1" if subject_type == SubjectType.ACCOUNT else None
|
||||
assert AclStrategy().authorize(_ctx(subject_type=subject_type, account_id=account_id)) is expected
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies.TenantService.account_belongs_to_tenant")
|
||||
@patch("controllers.openapi.auth.strategies.db")
|
||||
def test_membership_strategy_uses_join_lookup(db_mock, member):
|
||||
member.return_value = True
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
member.assert_called_once_with(db_mock.session, "acc1", "t1")
|
||||
|
||||
|
||||
def test_membership_strategy_rejects_external_sso():
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False
|
||||
|
||||
|
||||
def test_app_authz_check_raises_when_strategy_denies():
|
||||
deny = SimpleNamespace(authorize=lambda c: False)
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
assert "subject_no_app_access" in str(exc.value.description)
|
||||
|
||||
|
||||
def test_app_authz_check_passes_when_strategy_allows():
|
||||
allow = SimpleNamespace(authorize=lambda c: True)
|
||||
AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
@ -0,0 +1,83 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import BearerCheck
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
reset_auth_ctx,
|
||||
try_get_auth_ctx,
|
||||
)
|
||||
|
||||
|
||||
def _ctx(bearer_token: str | None) -> Context:
|
||||
return Context(required_scope="apps:run", bearer_token=bearer_token)
|
||||
|
||||
|
||||
def test_bearer_check_rejects_missing_header():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx(None))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_rejects_unknown_prefix(get_auth):
|
||||
get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer")
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx("xxx_abc"))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_populates_context_and_publishes_auth_ctx(get_auth):
|
||||
tok_id = uuid.uuid4()
|
||||
authn = AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="a@x.com",
|
||||
subject_issuer=None,
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=tok_id,
|
||||
source="oauth-account",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="hash-1",
|
||||
verified_tenants={},
|
||||
)
|
||||
get_auth.return_value.authenticate.return_value = authn
|
||||
|
||||
app = Flask(__name__)
|
||||
ctx = _ctx("dfoa_abc")
|
||||
with app.test_request_context():
|
||||
BearerCheck()(ctx)
|
||||
try:
|
||||
assert ctx.subject_type == SubjectType.ACCOUNT
|
||||
assert ctx.subject_email == "a@x.com"
|
||||
assert ctx.scopes == frozenset({Scope.FULL})
|
||||
assert ctx.source == "oauth-account"
|
||||
assert ctx.token_id == tok_id
|
||||
assert ctx.token_hash == "hash-1"
|
||||
# BearerCheck must also publish the same identity on the
|
||||
# openapi auth ContextVar so the surface gate + downstream
|
||||
# handlers don't see two different identity sources between
|
||||
# the decorator + pipeline paths. The reset token is parked
|
||||
# on `ctx.auth_ctx_reset_token` for `Pipeline.guard` to
|
||||
# consume in its `finally`.
|
||||
published = try_get_auth_ctx()
|
||||
assert published is authn
|
||||
assert published.client_id == "difyctl"
|
||||
assert ctx.auth_ctx_reset_token is not None
|
||||
finally:
|
||||
# In production `Pipeline.guard` resets the ContextVar; in
|
||||
# this isolated step-level test we reset it ourselves so the
|
||||
# value doesn't leak into the next test on the same worker.
|
||||
assert ctx.auth_ctx_reset_token is not None
|
||||
reset_auth_ctx(ctx.auth_ctx_reset_token)
|
||||
@ -0,0 +1,157 @@
|
||||
"""Unit tests for WorkspaceMembershipCheck (Layer 0)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import WorkspaceMembershipCheck
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
|
||||
c = Context(required_scope="apps:read")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None
|
||||
c.cached_verified_tenants = cached_verified_tenants
|
||||
c.token_hash = token_hash
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def step():
|
||||
return WorkspaceMembershipCheck()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = True
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id=str(uuid.uuid4()),
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
account_id=None,
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={"t1": True},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx)
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={"t1": False},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_record.assert_called_once_with("hash-1", "t1", False)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
|
||||
]
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_record.assert_called_once_with("hash-1", "t1", False)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_allows_active_member(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
|
||||
]
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_record.assert_called_once_with("hash-1", "t1", True)
|
||||
@ -0,0 +1,77 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import CallerMount
|
||||
from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id=None, subject_email=None):
|
||||
c = Context(required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.subject_email = subject_email
|
||||
c.app = SimpleNamespace(id="app1")
|
||||
c.tenant = SimpleNamespace(id="t1")
|
||||
return c
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._login_as")
|
||||
@patch("controllers.openapi.auth.strategies.db")
|
||||
def test_account_mounter(db, login):
|
||||
account = SimpleNamespace()
|
||||
db.session.get.return_value = account
|
||||
ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1")
|
||||
AccountMounter().mount(ctx)
|
||||
assert ctx.caller is account
|
||||
assert ctx.caller.current_tenant is ctx.tenant
|
||||
assert ctx.caller_kind == "account"
|
||||
login.assert_called_once_with(account)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._login_as")
|
||||
@patch("controllers.openapi.auth.strategies.EndUserService")
|
||||
def test_end_user_mounter(svc, login):
|
||||
eu = SimpleNamespace()
|
||||
svc.get_or_create_end_user_by_type.return_value = eu
|
||||
ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com")
|
||||
EndUserMounter().mount(ctx)
|
||||
svc.get_or_create_end_user_by_type.assert_called_once_with(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id="t1",
|
||||
app_id="app1",
|
||||
user_id="a@x.com",
|
||||
)
|
||||
assert ctx.caller is eu
|
||||
assert ctx.caller_kind == "end_user"
|
||||
|
||||
|
||||
def test_caller_mount_dispatches_by_subject_type():
|
||||
seen = {}
|
||||
|
||||
class Fake:
|
||||
def __init__(self, st, tag):
|
||||
self._st, self._tag = st, tag
|
||||
|
||||
def applies_to(self, st):
|
||||
return st == self._st
|
||||
|
||||
def mount(self, ctx):
|
||||
seen["who"] = self._tag
|
||||
|
||||
cm = CallerMount(
|
||||
Fake(SubjectType.ACCOUNT, "acct"),
|
||||
Fake(SubjectType.EXTERNAL_SSO, "sso"),
|
||||
)
|
||||
cm(_ctx(subject_type=SubjectType.EXTERNAL_SSO))
|
||||
assert seen == {"who": "sso"}
|
||||
|
||||
|
||||
def test_caller_mount_raises_when_none_applies():
|
||||
with pytest.raises(Unauthorized):
|
||||
CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
@ -0,0 +1,25 @@
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import ScopeCheck
|
||||
|
||||
|
||||
def _ctx(scopes, required):
|
||||
c = Context(required_scope=required)
|
||||
c.scopes = frozenset(scopes)
|
||||
return c
|
||||
|
||||
|
||||
def test_scope_check_passes_on_full():
|
||||
ScopeCheck()(_ctx({"full"}, "apps:run"))
|
||||
|
||||
|
||||
def test_scope_check_passes_on_explicit_match():
|
||||
ScopeCheck()(_ctx({"apps:run"}, "apps:run"))
|
||||
|
||||
|
||||
def test_scope_check_rejects_when_missing():
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
ScopeCheck()(_ctx({"apps:read"}, "apps:run"))
|
||||
assert "insufficient_scope" in str(exc.value.description)
|
||||
@ -0,0 +1,239 @@
|
||||
"""Surface gate tests.
|
||||
|
||||
The gate has two attachment forms — decorator (`accept_subjects`) and
|
||||
pipeline step (`SurfaceCheck`) — and both must:
|
||||
- 403 on mismatched subject type with a canonical-path hint
|
||||
- emit `openapi.wrong_surface_denied` once with the right payload
|
||||
- pass-through on match
|
||||
- raise RuntimeError (not 403) if the auth ContextVar is unset — that's
|
||||
a wiring bug, not a user-driven failure
|
||||
|
||||
Identity is published via `libs.oauth_bearer.set_auth_ctx` / read with
|
||||
`try_get_auth_ctx`. Tests wrap the publish in a `_publish_auth_ctx`
|
||||
context manager so the ContextVar resets even when an assertion fails;
|
||||
that keeps state from leaking into the next test on the same worker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import SurfaceCheck
|
||||
from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _publish_auth_ctx(ctx: AuthContext) -> Iterator[None]:
|
||||
token = set_auth_ctx(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
reset_auth_ctx(token)
|
||||
|
||||
|
||||
def _account_ctx() -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="user@example.com",
|
||||
subject_issuer="dify:account",
|
||||
account_id=uuid.uuid4(),
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
def _sso_ctx() -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
subject_email="sso@partner.com",
|
||||
subject_issuer="https://idp.partner.com",
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_external_sso",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h2",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_surface — shared core
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_check_surface_passes_when_subject_in_accepted():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_account_ctx()):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT})) # no raise
|
||||
|
||||
|
||||
def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/permitted-external-apps"), _publish_auth_ctx(_account_ctx()):
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
check_surface(frozenset({SubjectType.EXTERNAL_SSO}))
|
||||
assert "wrong_surface" in exc.value.description
|
||||
# canonical-path hint should point at the caller's surface,
|
||||
# not the surface they were rejected from
|
||||
assert "/openapi/v1/apps" in exc.value.description
|
||||
emit.assert_called_once()
|
||||
kwargs = emit.call_args.kwargs
|
||||
assert kwargs["subject_type"] == SubjectType.ACCOUNT.value
|
||||
assert kwargs["attempted_path"] == "/openapi/v1/permitted-external-apps"
|
||||
assert kwargs["client_id"] == "difyctl"
|
||||
assert kwargs["token_id"] is not None
|
||||
|
||||
|
||||
def test_check_surface_rejects_sso_on_account_surface():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_sso_ctx()):
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT}))
|
||||
kwargs = emit.call_args.kwargs
|
||||
assert kwargs["subject_type"] == SubjectType.EXTERNAL_SSO.value
|
||||
|
||||
|
||||
def test_check_surface_runtime_error_when_auth_ctx_missing():
|
||||
"""Missing auth ContextVar means the bearer layer didn't run — wiring
|
||||
bug, not a user-driven failure. Surface as RuntimeError (loud) so a
|
||||
future refactor doesn't accidentally let a route skip authentication
|
||||
and return a 403 that looks identical to a legitimate wrong-surface
|
||||
deny.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
with pytest.raises(RuntimeError):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# @accept_subjects — decorator form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/account-only")
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def _account_only():
|
||||
return "ok"
|
||||
|
||||
@app.route("/external-only")
|
||||
@accept_subjects(SubjectType.EXTERNAL_SSO)
|
||||
def _external_only():
|
||||
return "ok"
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_accept_subjects_decorator_passes_on_match():
|
||||
app = _make_app()
|
||||
with app.test_request_context("/account-only"), _publish_auth_ctx(_account_ctx()):
|
||||
# Re-route through the decorated function by reaching for view_function
|
||||
view = app.view_functions["_account_only"]
|
||||
assert view() == "ok"
|
||||
|
||||
|
||||
def test_accept_subjects_decorator_403_on_miss():
|
||||
app = _make_app()
|
||||
with app.test_request_context("/external-only"), _publish_auth_ctx(_account_ctx()):
|
||||
view = app.view_functions["_external_only"]
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"):
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SurfaceCheck — pipeline step form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _pipeline_ctx() -> Context:
|
||||
# SurfaceCheck reads ``request.path`` from Flask's global request — set up
|
||||
# via ``app.test_request_context`` in the calling tests — not from Context.
|
||||
return Context(required_scope=Scope.APPS_RUN)
|
||||
|
||||
|
||||
def test_surface_check_passes_on_match():
|
||||
step = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT}))
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
|
||||
step(_pipeline_ctx()) # no raise
|
||||
|
||||
|
||||
def test_surface_check_rejects_on_miss_and_emits_audit():
|
||||
step = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO}))
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden):
|
||||
step(_pipeline_ctx())
|
||||
emit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _coerce_subject_type — normalises whatever sat on ctx.subject_type
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# The gate reads `ctx.subject_type` via `getattr(..., None)`, so the value
|
||||
# could be a real enum (happy path), a raw string (e.g. rehydrated from a
|
||||
# dict-shaped context), `None` (attribute missing), or something unexpected
|
||||
# from a buggy upstream. The coercer must collapse all of that to
|
||||
# `SubjectType | None` so `check_surface` can do a clean set-membership
|
||||
# check and emit a clean audit payload.
|
||||
|
||||
|
||||
def test_coerce_subject_type_returns_none_for_none():
|
||||
assert _coerce_subject_type(None) is None
|
||||
|
||||
|
||||
def test_coerce_subject_type_returns_enum_instance_unchanged():
|
||||
# Identity matters: we don't want to round-trip through the string
|
||||
# constructor for an already-valid enum.
|
||||
assert _coerce_subject_type(SubjectType.ACCOUNT) is SubjectType.ACCOUNT
|
||||
assert _coerce_subject_type(SubjectType.EXTERNAL_SSO) is SubjectType.EXTERNAL_SSO
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw", "expected"),
|
||||
[
|
||||
("account", SubjectType.ACCOUNT),
|
||||
("external_sso", SubjectType.EXTERNAL_SSO),
|
||||
],
|
||||
)
|
||||
def test_coerce_subject_type_parses_known_strings(raw: str, expected: SubjectType):
|
||||
assert _coerce_subject_type(raw) is expected
|
||||
|
||||
|
||||
def test_coerce_subject_type_raises_on_unknown_string():
|
||||
# Unknown strings reach `SubjectType(raw)` which raises ValueError.
|
||||
# We surface that loudly rather than silently returning None, because
|
||||
# a string that *looks* like a subject type but isn't is almost
|
||||
# certainly an upstream bug worth catching.
|
||||
with pytest.raises(ValueError):
|
||||
_coerce_subject_type("not_a_subject")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raw", [123, 1.5, b"account", object(), ["account"], {"account"}])
|
||||
def test_coerce_subject_type_returns_none_for_non_string_non_enum(raw: object):
|
||||
assert _coerce_subject_type(raw) is None
|
||||
@ -1,156 +0,0 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_access,
|
||||
check_membership,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Tenant
|
||||
from models.model import App
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
def _data(**kwargs) -> AuthData:
|
||||
defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "hash", "scopes": frozenset({Scope.FULL})}
|
||||
defaults.update(kwargs)
|
||||
return AuthData(**defaults)
|
||||
|
||||
|
||||
# --- check_scope ---
|
||||
|
||||
def test_check_scope_passes_when_required_is_none():
|
||||
check_scope(_data(required_scope=None))
|
||||
|
||||
|
||||
def test_check_scope_passes_when_full_in_scopes():
|
||||
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.FULL})))
|
||||
|
||||
|
||||
def test_check_scope_passes_when_exact_scope_present():
|
||||
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_RUN})))
|
||||
|
||||
|
||||
def test_check_scope_raises_forbidden_when_scope_missing():
|
||||
with pytest.raises(Forbidden, match="insufficient_scope"):
|
||||
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_READ})))
|
||||
|
||||
|
||||
# reject_external_sso is no longer a pipeline step — PipelineRouter._execute raises
|
||||
# Forbidden("external_sso_requires_ee") directly when route.required_edition is not satisfied.
|
||||
# Test coverage for this is in test_pipeline.py (test_router_rejects_token_type_on_wrong_edition).
|
||||
|
||||
# --- check_membership ---
|
||||
|
||||
def test_check_membership_raises_unauthorized_when_tenant_none():
|
||||
with pytest.raises(Unauthorized):
|
||||
check_membership(_data(tenant=None))
|
||||
|
||||
|
||||
def test_check_membership_calls_check_workspace_membership():
|
||||
tenant = MagicMock(spec=Tenant)
|
||||
tenant.id = "tenant-1"
|
||||
data = _data(
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="myhash",
|
||||
tenants={"tenant-1": True},
|
||||
tenant=tenant,
|
||||
)
|
||||
with patch("controllers.openapi.auth.verify.check_workspace_membership") as mock_cwm:
|
||||
check_membership(data)
|
||||
mock_cwm.assert_called_once_with(
|
||||
account_id=data.account_id,
|
||||
tenant_id="tenant-1",
|
||||
token_hash="myhash",
|
||||
membership_cache=data.tenants,
|
||||
)
|
||||
|
||||
|
||||
# --- check_app_access ---
|
||||
|
||||
def test_check_app_access_passes_when_tenant_none():
|
||||
check_app_access(_data(tenant=None))
|
||||
|
||||
|
||||
def test_check_app_access_passes_when_member():
|
||||
tenant = MagicMock(spec=Tenant)
|
||||
tenant.id = "t1"
|
||||
data = _data(account_id=uuid.uuid4(), tenant=tenant)
|
||||
with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=True):
|
||||
check_app_access(data)
|
||||
|
||||
|
||||
def test_check_app_access_raises_when_not_member():
|
||||
tenant = MagicMock(spec=Tenant)
|
||||
tenant.id = "t1"
|
||||
data = _data(account_id=uuid.uuid4(), tenant=tenant)
|
||||
with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=False):
|
||||
with pytest.raises(Forbidden, match="subject_no_app_access"):
|
||||
check_app_access(data)
|
||||
|
||||
|
||||
# --- check_acl ---
|
||||
|
||||
def test_check_acl_raises_when_app_or_mode_missing():
|
||||
with pytest.raises(Forbidden):
|
||||
check_acl(_data(app=None, app_access_mode=None))
|
||||
|
||||
|
||||
def test_check_acl_account_allowed_for_public():
|
||||
app = MagicMock(spec=App)
|
||||
data = _data(token_type=TokenType.OAUTH_ACCOUNT, app=app, app_access_mode=WebAppAccessMode.PUBLIC)
|
||||
check_acl(data)
|
||||
|
||||
|
||||
def test_check_acl_external_sso_blocked_for_private():
|
||||
app = MagicMock(spec=App)
|
||||
data = _data(
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
app=app,
|
||||
app_access_mode=WebAppAccessMode.PRIVATE,
|
||||
)
|
||||
with pytest.raises(Forbidden, match="subject_not_allowed_for_access_mode"):
|
||||
check_acl(data)
|
||||
|
||||
|
||||
def test_check_acl_external_sso_allowed_for_sso_verified():
|
||||
app = MagicMock(spec=App)
|
||||
data = _data(
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
app=app,
|
||||
app_access_mode=WebAppAccessMode.SSO_VERIFIED,
|
||||
)
|
||||
check_acl(data)
|
||||
|
||||
|
||||
# --- check_private_app_permission ---
|
||||
|
||||
def test_check_private_app_permission_raises_when_app_none():
|
||||
with pytest.raises(Forbidden):
|
||||
check_private_app_permission(_data(app=None))
|
||||
|
||||
|
||||
def test_check_private_app_permission_raises_when_user_not_allowed():
|
||||
app = MagicMock(spec=App)
|
||||
app.id = "app-1"
|
||||
data = _data(account_id=uuid.uuid4(), app=app)
|
||||
target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp"
|
||||
with patch(target, return_value=False):
|
||||
with pytest.raises(Forbidden, match="user_not_allowed_for_private_app"):
|
||||
check_private_app_permission(data)
|
||||
|
||||
|
||||
def test_check_private_app_permission_passes_when_allowed():
|
||||
app = MagicMock(spec=App)
|
||||
app.id = "app-1"
|
||||
data = _data(account_id=uuid.uuid4(), app=app)
|
||||
target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp"
|
||||
with patch(target, return_value=True):
|
||||
check_private_app_permission(data)
|
||||
@ -1,36 +1,20 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.pipeline import PipelineRouter
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
|
||||
|
||||
def _stub_execute(self, args, kwargs, view, *, scope=None, allowed_token_types=None, edition=None):
|
||||
"""Bypass all auth logic; inject minimal AuthData and call the view directly."""
|
||||
kwargs["auth_data"] = AuthData(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="test",
|
||||
token_id=uuid.uuid4(),
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
required_scope=scope,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bypass_pipeline(monkeypatch):
|
||||
"""Stub PipelineRouter._execute so endpoints skip real auth at request time.
|
||||
"""Stub Pipeline.run so endpoint decoration does not invoke real auth.
|
||||
|
||||
Module-level @auth_router.guard(...) captures the real router at import
|
||||
time — patching guard itself does nothing. Patching _execute on the class
|
||||
is the seam that fires at request time.
|
||||
Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real
|
||||
pipeline at import time; mocking the module attribute does not undo
|
||||
that. Patching Pipeline.run on the class is the bypass that actually
|
||||
works.
|
||||
"""
|
||||
monkeypatch.setattr(PipelineRouter, "_execute", _stub_execute)
|
||||
monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -86,7 +86,7 @@ def test_subject_match_for_account_filters_by_account_id():
|
||||
"""Account subject scopes queries via account_id."""
|
||||
import uuid as _uuid
|
||||
|
||||
from libs.oauth_bearer import AuthContext, SubjectType, TokenType
|
||||
from libs.oauth_bearer import AuthContext, SubjectType
|
||||
from services.oauth_device_flow import subject_match_clauses
|
||||
|
||||
aid = _uuid.uuid4()
|
||||
@ -98,7 +98,7 @@ def test_subject_match_for_account_filters_by_account_id():
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({"full"}),
|
||||
token_id=_uuid.uuid4(),
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
source="oauth_account",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
@ -116,7 +116,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer():
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
from libs.oauth_bearer import AuthContext, SubjectType, TokenType
|
||||
from libs.oauth_bearer import AuthContext, SubjectType
|
||||
from services.oauth_device_flow import subject_match_clauses
|
||||
|
||||
ctx = AuthContext(
|
||||
@ -127,7 +127,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer():
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({"apps:run"}),
|
||||
token_id=_uuid.uuid4(),
|
||||
token_type=TokenType.OAUTH_EXTERNAL_SSO,
|
||||
source="oauth_external_sso",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
|
||||
@ -57,11 +57,7 @@ def test_stop_task_endpoint_registered(openapi_app):
|
||||
|
||||
|
||||
def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, monkeypatch):
|
||||
import uuid
|
||||
|
||||
from controllers.openapi.app_run import AppRunTaskStopApi
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
|
||||
queue_mock = Mock()
|
||||
graph_mock = Mock()
|
||||
@ -73,23 +69,15 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, mo
|
||||
monkeypatch.setattr(run_module, "GraphEngineManager", graph_mock)
|
||||
monkeypatch.setattr(run_module, "redis_client", object())
|
||||
|
||||
auth_data = AuthData.model_construct(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="test",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
app=SimpleNamespace(id="app-1", tenant_id="t-1"),
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
api = AppRunTaskStopApi()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/task-1/stop", method="POST"):
|
||||
result = api.post.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="task-1",
|
||||
auth_data=auth_data,
|
||||
app_model=SimpleNamespace(id="app-1", tenant_id="t-1"),
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
queue_mock.set_stop_flag_no_user_check.assert_called_once_with("task-1")
|
||||
|
||||
@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
@ -12,23 +11,9 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.human_input import RecipientType
|
||||
|
||||
|
||||
def _make_auth_data(app_model, caller, caller_kind):
|
||||
return AuthData.model_construct(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="test",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
app=app_model,
|
||||
caller=caller,
|
||||
caller_kind=caller_kind,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenApiHumanInputFormGet:
|
||||
def test_get_success(self, app, bypass_pipeline, monkeypatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
@ -58,14 +43,15 @@ class TestOpenApiHumanInputFormGet:
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
resp = api.get.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
payload = json.loads(resp.get_data(as_text=True))
|
||||
@ -85,7 +71,6 @@ class TestOpenApiHumanInputFormGet:
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/bad"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -93,7 +78,9 @@ class TestOpenApiHumanInputFormGet:
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="bad",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
def test_get_form_wrong_app(self, app, bypass_pipeline, monkeypatch):
|
||||
@ -110,7 +97,6 @@ class TestOpenApiHumanInputFormGet:
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -118,7 +104,9 @@ class TestOpenApiHumanInputFormGet:
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
def test_get_form_wrong_surface(self, app, bypass_pipeline, monkeypatch):
|
||||
@ -138,7 +126,6 @@ class TestOpenApiHumanInputFormGet:
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -146,7 +133,9 @@ class TestOpenApiHumanInputFormGet:
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
|
||||
@ -183,7 +172,9 @@ class TestOpenApiHumanInputFormPost:
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
app_model=app_model,
|
||||
caller=caller,
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
@ -220,7 +211,9 @@ class TestOpenApiHumanInputFormPost:
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "end_user"),
|
||||
app_model=app_model,
|
||||
caller=caller,
|
||||
caller_kind="end_user",
|
||||
)
|
||||
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
|
||||
@ -3,30 +3,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def _make_auth_data(app_model, caller, caller_kind):
|
||||
return AuthData.model_construct(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
token_hash="test",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
app=app_model,
|
||||
caller=caller,
|
||||
caller_kind=caller_kind,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_run(
|
||||
*,
|
||||
app_id="app-1",
|
||||
@ -65,7 +50,6 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -73,7 +57,9 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
def test_not_found_when_run_belongs_to_different_app(self, app, bypass_pipeline, monkeypatch):
|
||||
@ -91,7 +77,6 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -99,7 +84,9 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
def test_account_caller_checks_created_by_account(self, app, bypass_pipeline, monkeypatch):
|
||||
@ -128,7 +115,6 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
api = self._get_api()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
@ -137,7 +123,9 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
assert resp.mimetype == "text/event-stream"
|
||||
|
||||
@ -155,7 +143,6 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
api = self._get_api()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
@ -164,7 +151,9 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
def test_end_user_caller_checks_created_by_end_user(self, app, bypass_pipeline, monkeypatch):
|
||||
@ -190,7 +179,6 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="eu-1")
|
||||
|
||||
api = self._get_api()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
@ -198,7 +186,9 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "end_user"),
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="eu-1"),
|
||||
caller_kind="end_user",
|
||||
)
|
||||
assert resp.mimetype == "text/event-stream"
|
||||
|
||||
@ -232,7 +222,6 @@ class TestOpenApiWorkflowEventsApi:
|
||||
from models.model import AppMode
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
caller = SimpleNamespace(id="acct-1")
|
||||
|
||||
api = self._get_api()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
|
||||
@ -240,7 +229,9 @@ class TestOpenApiWorkflowEventsApi:
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="wf-run-1",
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
assert resp.mimetype == "text/event-stream"
|
||||
chunks = list(resp.response)
|
||||
|
||||
@ -11,7 +11,6 @@ from libs.oauth_bearer import (
|
||||
SubjectType,
|
||||
TokenKind,
|
||||
TokenKindRegistry,
|
||||
TokenType,
|
||||
)
|
||||
|
||||
|
||||
@ -22,7 +21,7 @@ def _registry_with_resolver(resolver) -> TokenKindRegistry:
|
||||
prefix="dfoa_",
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
source="oauth_account",
|
||||
resolver=resolver,
|
||||
)
|
||||
]
|
||||
@ -64,7 +63,7 @@ def test_unknown_prefix_raises_generic_invalid_bearer():
|
||||
prefix="dfoa_",
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
source="oauth_account",
|
||||
resolver=MagicMock(),
|
||||
)
|
||||
]
|
||||
|
||||
@ -19,7 +19,6 @@ from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
TokenType,
|
||||
require_scope,
|
||||
reset_auth_ctx,
|
||||
set_auth_ctx,
|
||||
@ -51,7 +50,7 @@ def _ctx(scopes) -> AuthContext:
|
||||
client_id="difyctl",
|
||||
scopes=scopes,
|
||||
token_id=uuid.uuid4(),
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
source="oauth_account",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
|
||||
@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, require_workspace_member
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, require_workspace_member
|
||||
|
||||
|
||||
def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> AuthContext:
|
||||
@ -20,7 +20,7 @@ def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> Au
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=uuid.uuid4(),
|
||||
token_type=TokenType.OAUTH_ACCOUNT if account else TokenType.OAUTH_EXTERNAL_SSO,
|
||||
source="oauth_account",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants=dict(verified or {}),
|
||||
|
||||
@ -3079,20 +3079,12 @@
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"web/app/components/tools/edit-custom-collection-modal/get-schema.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"web/app/components/tools/edit-custom-collection-modal/index.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"react/set-state-in-effect": {
|
||||
"count": 4
|
||||
},
|
||||
|
||||
@ -141,18 +141,6 @@ describe('DuplicateAppModal', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByText('open-icon-picker'))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('Search emojis...')).toBeInTheDocument()
|
||||
})
|
||||
const emojiButton = document.querySelector('em-emoji')?.closest('button')
|
||||
expect(emojiButton).toBeTruthy()
|
||||
await user.click(emojiButton!)
|
||||
await user.click(screen.getByRole('button', { name: '#E4FBCC' }))
|
||||
await user.click(screen.getByRole('button', { name: /iconPicker\.ok/ }))
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByPlaceholderText('Search emojis...')).not.toBeInTheDocument()
|
||||
})
|
||||
await user.click(screen.getByText('open-icon-picker'))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('Search emojis...')).toBeInTheDocument()
|
||||
@ -165,9 +153,9 @@ describe('DuplicateAppModal', () => {
|
||||
|
||||
expect(onConfirm).toHaveBeenCalledWith(expect.objectContaining({
|
||||
name: 'Image App',
|
||||
icon_type: 'emoji',
|
||||
icon: expect.any(String),
|
||||
icon_background: '#E4FBCC',
|
||||
icon_type: 'image',
|
||||
icon: 'original-file',
|
||||
icon_background: undefined,
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
||||
@ -14,12 +14,13 @@ import {
|
||||
} from '@langgenius/dify-ui/drawer'
|
||||
import { FieldItem, FieldLabel, FieldRoot } from '@langgenius/dify-ui/field'
|
||||
import { FieldsetLegend, FieldsetRoot } from '@langgenius/dify-ui/fieldset'
|
||||
import { Input } from '@langgenius/dify-ui/input'
|
||||
import { Radio } from '@langgenius/dify-ui/radio'
|
||||
import { RadioGroup } from '@langgenius/dify-ui/radio-group'
|
||||
import { ScrollArea } from '@langgenius/dify-ui/scroll-area'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Infotip } from '@/app/components/base/infotip'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { AuthHeaderPrefix, AuthType } from '@/app/components/tools/types'
|
||||
|
||||
type Props = {
|
||||
@ -41,11 +42,11 @@ function SelectItem({ text, value, isChecked }: ItemProps) {
|
||||
<FieldLabel
|
||||
className={cn(
|
||||
isChecked ? 'border-2 border-util-colors-indigo-indigo-600 bg-components-panel-on-panel-item-bg shadow-sm' : 'border border-components-card-border',
|
||||
'mb-2 flex h-9 w-37.5 cursor-pointer items-center space-x-2 rounded-xl bg-components-panel-on-panel-item-bg pl-3 text-left outline-hidden hover:bg-components-panel-on-panel-item-bg-hover focus-visible:ring-1 focus-visible:ring-components-input-border-hover',
|
||||
'flex h-9 w-full min-w-0 cursor-pointer items-center gap-2 rounded-xl bg-components-panel-on-panel-item-bg px-3 text-left outline-hidden hover:bg-components-panel-on-panel-item-bg-hover focus-visible:ring-1 focus-visible:ring-components-input-border-hover',
|
||||
)}
|
||||
>
|
||||
<Radio value={value} />
|
||||
<div className="system-sm-regular text-text-primary">{text}</div>
|
||||
<div className="min-w-0 truncate system-sm-regular text-text-primary">{text}</div>
|
||||
</FieldLabel>
|
||||
</FieldItem>
|
||||
)
|
||||
@ -104,7 +105,7 @@ export default function ConfigCredential({
|
||||
: 'data-[swipe-direction=right]:right-2',
|
||||
)}
|
||||
>
|
||||
<DrawerContent className="flex min-h-0 flex-1 flex-col p-0 pb-0">
|
||||
<DrawerContent className="flex min-h-0 flex-1 flex-col overflow-hidden p-0 pb-0">
|
||||
<div className="shrink-0 border-b border-divider-regular py-4">
|
||||
<div className="flex h-6 items-center justify-between pr-5 pl-6">
|
||||
<DrawerTitle className="min-w-0 truncate system-xl-semibold text-text-primary">
|
||||
@ -116,128 +117,132 @@ export default function ConfigCredential({
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="min-h-0 flex-1 overflow-y-auto px-6 pt-2">
|
||||
<div className="space-y-4">
|
||||
<FieldRoot name="auth_type" className="contents">
|
||||
<FieldsetRoot
|
||||
render={(
|
||||
<RadioGroup<AuthType>
|
||||
className="space-x-3"
|
||||
value={tempCredential.auth_type}
|
||||
onValueChange={handleAuthTypeChange}
|
||||
<ScrollArea
|
||||
className="min-h-0 flex-1 overflow-hidden"
|
||||
slotClassNames={{
|
||||
viewport: 'overscroll-contain',
|
||||
content: 'space-y-4 pt-2 pr-8 pl-6',
|
||||
}}
|
||||
>
|
||||
<FieldRoot name="auth_type" className="contents">
|
||||
<FieldsetRoot
|
||||
render={(
|
||||
<RadioGroup<AuthType>
|
||||
className="grid grid-cols-[repeat(auto-fit,minmax(8.5rem,1fr))] gap-2"
|
||||
value={tempCredential.auth_type}
|
||||
onValueChange={handleAuthTypeChange}
|
||||
/>
|
||||
)}
|
||||
>
|
||||
<FieldsetLegend className="col-span-full py-2 system-sm-medium text-text-primary">
|
||||
{t('createTool.authMethod.type', { ns: 'tools' })}
|
||||
</FieldsetLegend>
|
||||
<SelectItem
|
||||
text={t('createTool.authMethod.types.none', { ns: 'tools' })}
|
||||
value={AuthType.none}
|
||||
isChecked={tempCredential.auth_type === AuthType.none}
|
||||
/>
|
||||
<SelectItem
|
||||
text={t('createTool.authMethod.types.api_key_header', { ns: 'tools' })}
|
||||
value={AuthType.apiKeyHeader}
|
||||
isChecked={tempCredential.auth_type === AuthType.apiKeyHeader}
|
||||
/>
|
||||
<SelectItem
|
||||
text={t('createTool.authMethod.types.api_key_query', { ns: 'tools' })}
|
||||
value={AuthType.apiKeyQuery}
|
||||
isChecked={tempCredential.auth_type === AuthType.apiKeyQuery}
|
||||
/>
|
||||
</FieldsetRoot>
|
||||
</FieldRoot>
|
||||
{tempCredential.auth_type === AuthType.apiKeyHeader && (
|
||||
<>
|
||||
<FieldRoot name="api_key_header_prefix" className="contents">
|
||||
<FieldsetRoot
|
||||
render={(
|
||||
<RadioGroup<AuthHeaderPrefix>
|
||||
className="grid grid-cols-[repeat(auto-fit,minmax(8.5rem,1fr))] gap-2"
|
||||
value={tempCredential.api_key_header_prefix}
|
||||
onValueChange={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value })}
|
||||
/>
|
||||
)}
|
||||
>
|
||||
<FieldsetLegend className="col-span-full py-2 system-sm-medium text-text-primary">
|
||||
{t('createTool.authHeaderPrefix.title', { ns: 'tools' })}
|
||||
</FieldsetLegend>
|
||||
<SelectItem
|
||||
text={t('createTool.authHeaderPrefix.types.basic', { ns: 'tools' })}
|
||||
value={AuthHeaderPrefix.basic}
|
||||
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.basic}
|
||||
/>
|
||||
)}
|
||||
>
|
||||
<FieldsetLegend className="py-2 system-sm-medium text-text-primary">
|
||||
{t('createTool.authMethod.type', { ns: 'tools' })}
|
||||
</FieldsetLegend>
|
||||
<SelectItem
|
||||
text={t('createTool.authMethod.types.none', { ns: 'tools' })}
|
||||
value={AuthType.none}
|
||||
isChecked={tempCredential.auth_type === AuthType.none}
|
||||
/>
|
||||
<SelectItem
|
||||
text={t('createTool.authMethod.types.api_key_header', { ns: 'tools' })}
|
||||
value={AuthType.apiKeyHeader}
|
||||
isChecked={tempCredential.auth_type === AuthType.apiKeyHeader}
|
||||
/>
|
||||
<SelectItem
|
||||
text={t('createTool.authMethod.types.api_key_query', { ns: 'tools' })}
|
||||
value={AuthType.apiKeyQuery}
|
||||
isChecked={tempCredential.auth_type === AuthType.apiKeyQuery}
|
||||
/>
|
||||
</FieldsetRoot>
|
||||
</FieldRoot>
|
||||
{tempCredential.auth_type === AuthType.apiKeyHeader && (
|
||||
<>
|
||||
<FieldRoot name="api_key_header_prefix" className="contents">
|
||||
<FieldsetRoot
|
||||
render={(
|
||||
<RadioGroup<AuthHeaderPrefix>
|
||||
className="space-x-3"
|
||||
value={tempCredential.api_key_header_prefix}
|
||||
onValueChange={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value })}
|
||||
/>
|
||||
)}
|
||||
<SelectItem
|
||||
text={t('createTool.authHeaderPrefix.types.bearer', { ns: 'tools' })}
|
||||
value={AuthHeaderPrefix.bearer}
|
||||
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.bearer}
|
||||
/>
|
||||
<SelectItem
|
||||
text={t('createTool.authHeaderPrefix.types.custom', { ns: 'tools' })}
|
||||
value={AuthHeaderPrefix.custom}
|
||||
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.custom}
|
||||
/>
|
||||
</FieldsetRoot>
|
||||
</FieldRoot>
|
||||
<div>
|
||||
<div className="flex items-center py-2 system-sm-medium text-text-primary">
|
||||
{t('createTool.authMethod.key', { ns: 'tools' })}
|
||||
<Infotip
|
||||
aria-label={t('createTool.authMethod.keyTooltip', { ns: 'tools' })}
|
||||
className="ml-0.5 size-4"
|
||||
popupClassName="w-[261px] text-text-tertiary"
|
||||
>
|
||||
<FieldsetLegend className="py-2 system-sm-medium text-text-primary">
|
||||
{t('createTool.authHeaderPrefix.title', { ns: 'tools' })}
|
||||
</FieldsetLegend>
|
||||
<SelectItem
|
||||
text={t('createTool.authHeaderPrefix.types.basic', { ns: 'tools' })}
|
||||
value={AuthHeaderPrefix.basic}
|
||||
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.basic}
|
||||
/>
|
||||
<SelectItem
|
||||
text={t('createTool.authHeaderPrefix.types.bearer', { ns: 'tools' })}
|
||||
value={AuthHeaderPrefix.bearer}
|
||||
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.bearer}
|
||||
/>
|
||||
<SelectItem
|
||||
text={t('createTool.authHeaderPrefix.types.custom', { ns: 'tools' })}
|
||||
value={AuthHeaderPrefix.custom}
|
||||
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.custom}
|
||||
/>
|
||||
</FieldsetRoot>
|
||||
</FieldRoot>
|
||||
<div>
|
||||
<div className="flex items-center py-2 system-sm-medium text-text-primary">
|
||||
{t('createTool.authMethod.key', { ns: 'tools' })}
|
||||
<Infotip
|
||||
aria-label={t('createTool.authMethod.keyTooltip', { ns: 'tools' })}
|
||||
className="ml-0.5 size-4"
|
||||
popupClassName="w-[261px] text-text-tertiary"
|
||||
>
|
||||
{t('createTool.authMethod.keyTooltip', { ns: 'tools' })}
|
||||
</Infotip>
|
||||
</div>
|
||||
<Input
|
||||
value={tempCredential.api_key_header}
|
||||
onChange={e => setTempCredential({ ...tempCredential, api_key_header: e.target.value })}
|
||||
placeholder={t('createTool.authMethod.types.apiKeyPlaceholder', { ns: 'tools' })!}
|
||||
/>
|
||||
{t('createTool.authMethod.keyTooltip', { ns: 'tools' })}
|
||||
</Infotip>
|
||||
</div>
|
||||
<div>
|
||||
<div className="py-2 system-sm-medium text-text-primary">{t('createTool.authMethod.value', { ns: 'tools' })}</div>
|
||||
<Input
|
||||
value={tempCredential.api_key_value}
|
||||
onChange={e => setTempCredential({ ...tempCredential, api_key_value: e.target.value })}
|
||||
placeholder={t('createTool.authMethod.types.apiValuePlaceholder', { ns: 'tools' })!}
|
||||
/>
|
||||
<Input
|
||||
value={tempCredential.api_key_header}
|
||||
onChange={e => setTempCredential({ ...tempCredential, api_key_header: e.target.value })}
|
||||
placeholder={t('createTool.authMethod.types.apiKeyPlaceholder', { ns: 'tools' })!}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<div className="py-2 system-sm-medium text-text-primary">{t('createTool.authMethod.value', { ns: 'tools' })}</div>
|
||||
<Input
|
||||
value={tempCredential.api_key_value}
|
||||
onChange={e => setTempCredential({ ...tempCredential, api_key_value: e.target.value })}
|
||||
placeholder={t('createTool.authMethod.types.apiValuePlaceholder', { ns: 'tools' })!}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{tempCredential.auth_type === AuthType.apiKeyQuery && (
|
||||
<>
|
||||
<div>
|
||||
<div className="flex items-center py-2 system-sm-medium text-text-primary">
|
||||
{t('createTool.authMethod.queryParam', { ns: 'tools' })}
|
||||
<Infotip
|
||||
aria-label={t('createTool.authMethod.queryParamTooltip', { ns: 'tools' })}
|
||||
className="ml-0.5 size-4"
|
||||
popupClassName="w-[261px] text-text-tertiary"
|
||||
>
|
||||
{t('createTool.authMethod.queryParamTooltip', { ns: 'tools' })}
|
||||
</Infotip>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{tempCredential.auth_type === AuthType.apiKeyQuery && (
|
||||
<>
|
||||
<div>
|
||||
<div className="flex items-center py-2 system-sm-medium text-text-primary">
|
||||
{t('createTool.authMethod.queryParam', { ns: 'tools' })}
|
||||
<Infotip
|
||||
aria-label={t('createTool.authMethod.queryParamTooltip', { ns: 'tools' })}
|
||||
className="ml-0.5 size-4"
|
||||
popupClassName="w-[261px] text-text-tertiary"
|
||||
>
|
||||
{t('createTool.authMethod.queryParamTooltip', { ns: 'tools' })}
|
||||
</Infotip>
|
||||
</div>
|
||||
<Input
|
||||
value={tempCredential.api_key_query_param}
|
||||
onChange={e => setTempCredential({ ...tempCredential, api_key_query_param: e.target.value })}
|
||||
placeholder={t('createTool.authMethod.types.queryParamPlaceholder', { ns: 'tools' })!}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<div className="py-2 system-sm-medium text-text-primary">{t('createTool.authMethod.value', { ns: 'tools' })}</div>
|
||||
<Input
|
||||
value={tempCredential.api_key_value}
|
||||
onChange={e => setTempCredential({ ...tempCredential, api_key_value: e.target.value })}
|
||||
placeholder={t('createTool.authMethod.types.apiValuePlaceholder', { ns: 'tools' })!}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<Input
|
||||
value={tempCredential.api_key_query_param}
|
||||
onChange={e => setTempCredential({ ...tempCredential, api_key_query_param: e.target.value })}
|
||||
placeholder={t('createTool.authMethod.types.queryParamPlaceholder', { ns: 'tools' })!}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<div className="py-2 system-sm-medium text-text-primary">{t('createTool.authMethod.value', { ns: 'tools' })}</div>
|
||||
<Input
|
||||
value={tempCredential.api_key_value}
|
||||
onChange={e => setTempCredential({ ...tempCredential, api_key_value: e.target.value })}
|
||||
placeholder={t('createTool.authMethod.types.apiValuePlaceholder', { ns: 'tools' })!}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</ScrollArea>
|
||||
<div className="mt-4 flex shrink-0 justify-end space-x-2 px-6 py-4">
|
||||
<Button onClick={onHide}>{t('operation.cancel', { ns: 'common' })}</Button>
|
||||
<Button
|
||||
|
||||
@ -13,6 +13,8 @@ import {
|
||||
DrawerTitle,
|
||||
DrawerViewport,
|
||||
} from '@langgenius/dify-ui/drawer'
|
||||
import { Input } from '@langgenius/dify-ui/input'
|
||||
import { ScrollArea } from '@langgenius/dify-ui/scroll-area'
|
||||
import { Textarea } from '@langgenius/dify-ui/textarea'
|
||||
import { toast } from '@langgenius/dify-ui/toast'
|
||||
import { RiSettings2Line } from '@remixicon/react'
|
||||
@ -23,7 +25,6 @@ import { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import EmojiPicker from '@/app/components/base/emoji-picker'
|
||||
import Input from '@/app/components/base/input'
|
||||
import LabelSelector from '@/app/components/tools/labels/selector'
|
||||
import { parseParamsSchema } from '@/service/tools'
|
||||
import { LinkExternal02 } from '../../base/icons/src/vender/line/general'
|
||||
@ -220,7 +221,7 @@ const EditCustomCollectionModal: FC<Props> = ({
|
||||
: 'data-[swipe-direction=right]:right-2',
|
||||
)}
|
||||
>
|
||||
<DrawerContent className="flex min-h-0 flex-1 flex-col p-0 pb-0">
|
||||
<DrawerContent className="flex min-h-0 flex-1 flex-col overflow-hidden p-0 pb-0">
|
||||
<div className="shrink-0 border-b border-divider-regular py-4">
|
||||
<div className="flex h-6 items-center justify-between pr-5 pl-6">
|
||||
<DrawerTitle className="min-w-0 truncate system-xl-semibold text-text-primary">
|
||||
@ -233,8 +234,14 @@ const EditCustomCollectionModal: FC<Props> = ({
|
||||
</div>
|
||||
</div>
|
||||
<div className="min-h-0 flex-1">
|
||||
<div className="flex h-full flex-col">
|
||||
<div className="h-0 grow space-y-4 overflow-y-auto px-6 py-3">
|
||||
<div className="flex h-full min-h-0 flex-col">
|
||||
<ScrollArea
|
||||
className="min-h-0 flex-1 overflow-hidden"
|
||||
slotClassNames={{
|
||||
viewport: 'overscroll-contain',
|
||||
content: 'space-y-4 py-3 pr-8 pl-6',
|
||||
}}
|
||||
>
|
||||
<div>
|
||||
<div className="py-2 system-sm-medium text-text-primary">
|
||||
{t('createTool.name', { ns: 'tools' })}
|
||||
@ -373,7 +380,7 @@ const EditCustomCollectionModal: FC<Props> = ({
|
||||
/>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</ScrollArea>
|
||||
<div className={cn(isEdit ? 'justify-between' : 'justify-end', 'mt-2 flex shrink-0 rounded-b-[10px] border-t border-divider-regular bg-background-section-burn px-6 py-4')}>
|
||||
{
|
||||
isEdit && (
|
||||
|
||||
Reference in New Issue
Block a user