mirror of
https://github.com/langgenius/dify.git
synced 2026-05-27 04:16:16 +08:00
Replace the single mutable-context Pipeline with a two-phase, condition-driven system dispatched by token type.

New architecture:
- TokenType (StrEnum) replaces `source: str` on AuthContext / TokenKind.
- AuthPipeline: a pure prepare→auth step runner; it has no guard().
- PipelineRoute: binds an AuthPipeline to an optional required_edition gate.
- PipelineRouter: the single guard() entry point; it runs the edition/license/token-type pre-gates, then dispatches to the pipeline registered for the token type.
- Cond / When: composable predicates for conditional step dispatch.
- AuthData: a frozen Pydantic model produced by the prepare phase; it carries token_id so endpoints don't need to call get_auth_ctx() for identity fields.
- Edition enum + current_edition(): the CE / EE / SAAS discriminator.

Two pipelines in composition.py:
- account_pipeline — OAUTH_ACCOUNT tokens.
- external_sso_pipeline — OAUTH_EXTERNAL_SSO tokens (EE enforced at the route level).

All /openapi/v1 endpoints were migrated to auth_router.guard(). The old context.py, steps.py, strategies.py, and surface_gate.py were deleted. The WORKSPACE_READ scope was added, and cached_verdicts was renamed to membership_cache.
205 lines
7.4 KiB
Python
205 lines
7.4 KiB
Python
import uuid
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from flask import Flask
|
|
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
|
|
|
from controllers.openapi.auth.data import AuthData, Edition
|
|
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
|
from libs.oauth_bearer import Scope, TokenType
|
|
|
|
|
|
def _make_identity(
    token_type=TokenType.OAUTH_ACCOUNT,
    account_id=None,
    scopes=None,
    token_hash="testhash",
    subject_email=None,
    subject_issuer=None,
    verified_tenants=None,
    token_id=None,
):
    """Build a MagicMock standing in for the identity returned by ``authenticate()``.

    Any optional argument left as ``None`` gets a fresh default:
    random UUIDs for ``account_id`` and ``token_id``, ``{Scope.FULL}`` for
    ``scopes``, and a new empty dict for ``verified_tenants``.

    Defaults are applied with ``is None`` checks rather than ``or``. A caller can
    therefore deliberately pass a falsy value (e.g. ``scopes=frozenset()`` to model
    a token with no scopes) without it being silently replaced by a default.
    """
    identity = MagicMock()
    identity.token_type = token_type
    identity.account_id = uuid.uuid4() if account_id is None else account_id
    identity.scopes = frozenset({Scope.FULL}) if scopes is None else scopes
    identity.token_hash = token_hash
    identity.subject_email = subject_email
    identity.subject_issuer = subject_issuer
    identity.verified_tenants = {} if verified_tenants is None else verified_tenants
    identity.token_id = uuid.uuid4() if token_id is None else token_id
    return identity
|
|
|
|
|
|
@pytest.fixture
def app():
    """Provide a bare Flask app; the tests only use it for ``test_request_context``."""
    flask_app = Flask(__name__)
    return flask_app
|
|
|
|
|
|
def _make_router(token_type=TokenType.OAUTH_ACCOUNT, prepare=None, auth=None):
    """Return a ``PipelineRouter`` with exactly one route, keyed by *token_type*.

    The route wraps an ``AuthPipeline`` built from the given ``prepare`` and
    ``auth`` step lists; each defaults to an empty list when not supplied.
    No ``required_edition`` argument is passed to the ``PipelineRoute``.
    """
    prepare_steps = prepare if prepare else []
    auth_steps = auth if auth else []
    route = PipelineRoute(AuthPipeline(prepare=prepare_steps, auth=auth_steps))
    return PipelineRouter({token_type: route})
|
|
|
|
|
|
def _fake_identity():
    """Shorthand for an OAUTH_ACCOUNT identity using every ``_make_identity`` default."""
    fake = _make_identity()
    return fake
|
|
|
|
|
|
# --- PipelineRouter.guard ---
|
|
|
|
def test_guard_passes_auth_data_to_view(app):
    """A request that clears every gate reaches the view with an ``AuthData`` kwarg."""
    captured = {}
    router = _make_router()

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
        patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
        patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
    ):
        mock_auth.return_value.authenticate.return_value = _fake_identity()

        @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
        def view(*, auth_data):
            captured["data"] = auth_data

        view()

    assert isinstance(captured["data"], AuthData)
|
|
|
|
|
|
def test_guard_edition_gate_returns_404(app):
    """An endpoint restricted to EE answers 404 when the deployment reports CE.

    No Authorization header is sent and ``extract_bearer`` is not patched, so the
    test presumably relies on the edition gate firing before bearer extraction.
    """
    router = _make_router()

    with (
        app.test_request_context("/test"),
        patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
    ):

        @router.guard(scope=Scope.FULL, edition=frozenset({Edition.EE}))
        def view(*, auth_data):
            pass

        with pytest.raises(NotFound):
            view()
|
|
|
|
|
|
def test_guard_token_type_gate_returns_403(app):
    """A token type outside the endpoint's ``allowed_token_types`` is rejected with 403."""
    router = _make_router()

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
        patch("controllers.openapi.auth.pipeline.emit_wrong_surface"),
        patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
    ):
        sso_identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO)
        mock_auth.return_value.authenticate.return_value = sso_identity

        @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
        def view(*, auth_data):
            pass

        with pytest.raises(Forbidden):
            view()
|
|
|
|
|
|
def test_guard_unregistered_token_type_returns_403(app):
    """A token type with no registered pipeline is rejected with 403.

    The router only knows OAUTH_ACCOUNT, and the guard sets no
    ``allowed_token_types``. The rejection must therefore come from the router's
    own dispatch, not from the per-endpoint token-type gate.
    """
    router = _make_router(token_type=TokenType.OAUTH_ACCOUNT)

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
        patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
    ):
        sso_identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO)
        mock_auth.return_value.authenticate.return_value = sso_identity

        @router.guard(scope=Scope.FULL)
        def view(*, auth_data):
            pass

        with pytest.raises(Forbidden):
            view()
|
|
|
|
|
|
def test_guard_no_bearer_returns_401(app):
    """When no bearer token can be extracted, the guard raises 401 Unauthorized."""
    router = _make_router()

    with (
        app.test_request_context("/test"),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value=None),
    ):

        @router.guard(scope=Scope.FULL)
        def view(*, auth_data):
            pass

        with pytest.raises(Unauthorized):
            view()
|
|
|
|
|
|
def test_guard_runs_prepare_steps_in_order(app):
    """Prepare steps run in the order they were registered on the pipeline."""
    calls = []

    def _recorder(label):
        # Build a prepare step that only records that it ran.
        def step(b):
            calls.append(label)

        return step

    router = _make_router(prepare=[_recorder("p1"), _recorder("p2")])

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
        patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()),
        patch("controllers.openapi.auth.pipeline.reset_auth_ctx"),
    ):
        mock_auth.return_value.authenticate.return_value = _fake_identity()

        @router.guard(scope=Scope.FULL)
        def view(*, auth_data):
            pass

        view()

    assert calls == ["p1", "p2"]
|
|
|
|
|
|
def test_guard_resets_auth_ctx_on_exception(app):
    """The auth context is reset even when the guarded view raises.

    ``set_auth_ctx`` returns a sentinel distinct from the bearer string ("tok").
    The assertion therefore proves that ``reset_auth_ctx`` received the value
    handed back by ``set_auth_ctx``. With both set to "tok", the test could not
    tell the context token apart from the bearer.
    """
    router = _make_router()
    ctx_token = "ctx-reset-token"

    with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
        with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
             patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
             patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=ctx_token), \
             patch("controllers.openapi.auth.pipeline.reset_auth_ctx") as mock_reset:
            mock_auth.return_value.authenticate.return_value = _fake_identity()

            @router.guard(scope=Scope.FULL)
            def view(*, auth_data):
                raise RuntimeError("boom")

            # match= ensures the error surfacing is the view's own, not one raised by the guard.
            with pytest.raises(RuntimeError, match="boom"):
                view()

    mock_reset.assert_called_once_with(ctx_token)
|
|
|
|
|
|
def test_router_rejects_token_type_on_wrong_edition(app):
    """A route gated to EE rejects its token type with 403 when running on CE."""
    ee_only_route = PipelineRoute(
        AuthPipeline(prepare=[], auth=[]),
        required_edition=frozenset({Edition.EE}),
    )
    router = PipelineRouter({TokenType.OAUTH_EXTERNAL_SSO: ee_only_route})

    with (
        app.test_request_context("/test", headers={"Authorization": "Bearer tok"}),
        patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"),
        patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth,
        patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE),
    ):
        sso_identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO)
        mock_auth.return_value.authenticate.return_value = sso_identity

        @router.guard(scope=Scope.APPS_RUN)
        def view(*, auth_data):
            pass

        with pytest.raises(Forbidden):
            view()
|