Files
dify/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py
Xiyuan Chen d2788d7aba feat(openapi): redesign auth pipeline with per-token-type routing (#36693)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-27 12:45:30 +00:00

270 lines
9.7 KiB
Python

import uuid
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized

from controllers.openapi.auth.data import AuthData, Edition, ExternalIdentity
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
from libs.oauth_bearer import Scope, TokenType
def _make_identity(
    token_type=TokenType.OAUTH_ACCOUNT,
    account_id=None,
    scopes=None,
    token_hash="testhash",
    subject_email=None,
    subject_issuer=None,
    verified_tenants=None,
    token_id=None,
):
    """Build a ``MagicMock`` standing in for an authenticated token identity.

    Defaults are applied only when an argument is ``None``. Explicitly passed
    falsy values (for example ``scopes=frozenset()``) are stored as given
    rather than being swapped for the default, so a test can model a token
    that carries no scopes at all.

    Args:
        token_type: Stored on ``identity.token_type``.
        account_id: Stored on ``identity.account_id``; a fresh ``uuid4`` when ``None``.
        scopes: Stored on ``identity.scopes``; ``frozenset({Scope.FULL})`` when ``None``.
        token_hash: Stored on ``identity.token_hash``.
        subject_email: Stored on ``identity.subject_email``.
        subject_issuer: Stored on ``identity.subject_issuer``.
        verified_tenants: Stored on ``identity.verified_tenants``; ``{}`` when ``None``.
        token_id: Stored on ``identity.token_id``; a fresh ``uuid4`` when ``None``.

    Returns:
        The configured ``MagicMock``.
    """
    identity = MagicMock()
    identity.token_type = token_type
    identity.account_id = uuid.uuid4() if account_id is None else account_id
    identity.scopes = frozenset({Scope.FULL}) if scopes is None else scopes
    identity.token_hash = token_hash
    identity.subject_email = subject_email
    identity.subject_issuer = subject_issuer
    identity.verified_tenants = {} if verified_tenants is None else verified_tenants
    identity.token_id = uuid.uuid4() if token_id is None else token_id
    return identity
@pytest.fixture
def app():
    """Provide a bare Flask application for the tests' request contexts."""
    flask_app = Flask(__name__)
    return flask_app
def _make_router(token_type=TokenType.OAUTH_ACCOUNT, prepare=None, auth=None):
    """Return a ``PipelineRouter`` with one route registered under *token_type*.

    ``prepare`` and ``auth`` are forwarded to ``AuthPipeline``; a falsy value
    for either one is replaced by an empty list.
    """
    prepare_steps = prepare if prepare else []
    auth_steps = auth if auth else []
    route = PipelineRoute(AuthPipeline(prepare=prepare_steps, auth=auth_steps))
    return PipelineRouter({token_type: route})
def _fake_identity():
    """Return a mock identity built entirely from ``_make_identity`` defaults."""
    default_identity = _make_identity()
    return default_identity
# --- PipelineRouter.guard ---
def test_guard_passes_auth_data_to_view(app):
    """A guarded view is called with an ``AuthData`` instance as ``auth_data``."""
    captured = {}
    router = _make_router()
    identity = _fake_identity()
    account_only = frozenset({TokenType.OAUTH_ACCOUNT})

    # Single ``with``: the request context is entered first, then the patches.
    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as authenticator_factory,
        patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
        patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
    ):
        authenticator_factory.return_value.authenticate.return_value = identity

        @router.guard(scope=Scope.FULL, allowed_token_types=account_only)
        def view(*, auth_data):
            captured["data"] = auth_data

        view()

    assert isinstance(captured["data"], AuthData)
def test_guard_edition_gate_returns_404(app):
    """An EE-only guard raises ``NotFound`` while the edition is patched to CE."""
    router = _make_router()
    ee_only = frozenset({Edition.EE})

    with (
        app.test_request_context("/test"),
        patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
    ):

        @router.guard(scope=Scope.FULL, edition=ee_only)
        def view(*, auth_data):
            pass

        with pytest.raises(NotFound):
            view()
def test_guard_token_type_gate_returns_403(app):
    """The guard raises ``Forbidden`` for a token type outside ``allowed_token_types``.

    The mocked authenticator yields an ``OAUTH_EXTERNAL_SSO`` identity while
    the view only allows ``OAUTH_ACCOUNT``.
    """
    router = _make_router()
    account_only = frozenset({TokenType.OAUTH_ACCOUNT})
    sso_identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO)

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as authenticator_factory,
        patch("controllers.openapi.auth.pipeline.emit_wrong_surface"),
        patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
    ):
        authenticator_factory.return_value.authenticate.return_value = sso_identity

        @router.guard(scope=Scope.FULL, allowed_token_types=account_only)
        def view(*, auth_data):
            pass

        with pytest.raises(Forbidden):
            view()
def test_guard_unregistered_token_type_returns_403(app):
    """The guard raises ``Forbidden`` when the router has no route for the token's type.

    The router only registers ``OAUTH_ACCOUNT``; the mocked authenticator
    returns an ``OAUTH_EXTERNAL_SSO`` identity.
    """
    router = _make_router(token_type=TokenType.OAUTH_ACCOUNT)
    unrouted_identity = _fake_identity()
    unrouted_identity.token_type = TokenType.OAUTH_EXTERNAL_SSO

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as authenticator_factory,
        patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
    ):
        authenticator_factory.return_value.authenticate.return_value = unrouted_identity

        @router.guard(scope=Scope.FULL)
        def view(*, auth_data):
            pass

        with pytest.raises(Forbidden):
            view()
def test_guard_no_bearer_returns_401(app):
    """With ``extract_bearer`` patched to return ``None`` the guard raises ``Unauthorized``."""
    router = _make_router()
    no_bearer = patch("controllers.openapi.auth.pipeline.extract_bearer", return_value=None)

    with app.test_request_context("/test"), no_bearer:

        @router.guard(scope=Scope.FULL)
        def view(*, auth_data):
            pass

        with pytest.raises(Unauthorized):
            view()
def test_guard_runs_prepare_steps_in_order(app):
    """Prepare steps are invoked in the order they were given to the pipeline."""
    calls = []

    def _recording_step(label):
        # Build a prepare step that only records its label when invoked.
        def step(b):
            calls.append(label)

        return step

    router = _make_router(prepare=[_recording_step("p1"), _recording_step("p2")])
    identity = _fake_identity()

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as authenticator_factory,
        patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
        patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
    ):
        authenticator_factory.return_value.authenticate.return_value = identity

        @router.guard(scope=Scope.FULL)
        def view(*, auth_data):
            pass

        view()

    assert calls == ["p1", "p2"]
def test_guard_resets_auth_ctx_on_exception(app):
    """``reset_auth_ctx`` is still called when the guarded view raises.

    The value returned by the mocked ``set_auth_ctx`` is deliberately
    different from the mocked bearer string ``"tok"``. The final assertion
    can therefore only pass if ``reset_auth_ctx`` receives the context token,
    not the bearer value.
    """
    router = _make_router()
    ctx_token = "ctx-token"  # must differ from the bearer string "tok"
    reset_called = []

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
        patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=ctx_token),
        patch("controllers.openapi.auth.pipeline.reset_auth_ctx", side_effect=reset_called.append),
    ):
        mock_auth.return_value.authenticate.return_value = _fake_identity()

        @router.guard(scope=Scope.FULL)
        def view(*, auth_data):
            raise RuntimeError("boom")

        # The view's own exception must propagate unchanged.
        with pytest.raises(RuntimeError):
            view()

    # Exactly one reset, carrying the token produced by set_auth_ctx.
    assert reset_called == [ctx_token]
def test_router_rejects_token_type_on_wrong_edition(app):
    """A route requiring EE raises ``Forbidden`` while the edition is patched to CE."""
    ee_route = PipelineRoute(
        AuthPipeline(prepare=[], auth=[]),
        required_edition=frozenset({Edition.EE}),
    )
    router = PipelineRouter({TokenType.OAUTH_EXTERNAL_SSO: ee_route})
    sso_identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO)

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as authenticator_factory,
        patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
    ):
        authenticator_factory.return_value.authenticate.return_value = sso_identity

        @router.guard(scope=Scope.APPS_RUN)
        def view(*, auth_data):
            pass

        with pytest.raises(Forbidden):
            view()
def test_guard_populates_external_identity_from_subject_email(app):
    """``auth_data.external_identity`` reflects the identity's subject email and issuer.

    The mocked authenticator returns an ``OAUTH_EXTERNAL_SSO`` identity with
    ``subject_email`` and ``subject_issuer`` set. The view must then see an
    ``ExternalIdentity`` whose ``email`` and ``issuer`` match those values.
    """
    router = _make_router(token_type=TokenType.OAUTH_EXTERNAL_SSO)
    received = {}
    identity = _make_identity(
        token_type=TokenType.OAUTH_EXTERNAL_SSO,
        subject_email="user@sso.com",
        subject_issuer="https://idp.example.com",
    )

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
        patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
        patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
    ):
        mock_auth.return_value.authenticate.return_value = identity

        @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}))
        def view(*, auth_data):
            received["data"] = auth_data

        view()

    external = received["data"].external_identity
    assert isinstance(external, ExternalIdentity)
    assert external.email == "user@sso.com"
    assert external.issuer == "https://idp.example.com"
def test_guard_no_external_identity_when_subject_email_absent(app):
    """``auth_data.external_identity`` is ``None`` when the identity has no subject email."""
    captured = {}
    router = _make_router()
    identity_without_email = _make_identity(subject_email=None)
    account_only = frozenset({TokenType.OAUTH_ACCOUNT})

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as authenticator_factory,
        patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
        patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
    ):
        authenticator_factory.return_value.authenticate.return_value = identity_without_email

        @router.guard(scope=Scope.FULL, allowed_token_types=account_only)
        def view(*, auth_data):
            captured["data"] = auth_data

        view()

    assert captured["data"].external_identity is None