refactor: decouple Context from flask

This commit is contained in:
yunlu.wen
2026-05-23 10:32:41 +08:00
parent 341a82bf1e
commit 0c1b37687f
19 changed files with 251 additions and 157 deletions

View File

@ -1,10 +1,10 @@
from unittest.mock import MagicMock
from controllers.openapi.auth.context import Context
def test_context_starts_unpopulated():
ctx = Context(request=MagicMock(), required_scope="apps:run")
ctx = Context(required_scope="apps:run")
assert ctx.bearer_token is None
assert ctx.path_params == {}
assert ctx.subject_type is None
assert ctx.subject_email is None
assert ctx.account_id is None
@ -16,6 +16,6 @@ def test_context_starts_unpopulated():
def test_context_fields_are_mutable():
ctx = Context(request=MagicMock(), required_scope="apps:run")
ctx = Context(required_scope="apps:run")
ctx.scopes = frozenset({"full"})
assert "full" in ctx.scopes

View File

@ -1,5 +1,3 @@
from unittest.mock import MagicMock
import pytest
from flask import Flask
@ -17,7 +15,7 @@ def test_run_invokes_each_step_in_order():
def __call__(self, ctx):
calls.append(self.tag)
Pipeline(S("a"), S("b"), S("c")).run(Context(request=MagicMock(), required_scope="x"))
Pipeline(S("a"), S("b"), S("c")).run(Context(required_scope="x"))
assert calls == ["a", "b", "c"]
@ -33,7 +31,7 @@ def test_run_short_circuits_on_raise():
calls.append("ran")
with pytest.raises(RuntimeError):
Pipeline(Boom(), Tail()).run(Context(request=MagicMock(), required_scope="x"))
Pipeline(Boom(), Tail()).run(Context(required_scope="x"))
assert calls == []

View File

@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@ -9,10 +9,8 @@ from controllers.openapi.auth.steps import AppResolver
from models import TenantStatus
def _ctx(view_args):
req = MagicMock()
req.view_args = view_args
return Context(request=req, required_scope="apps:run")
def _ctx(path_params: dict[str, str] | None) -> Context:
return Context(required_scope="apps:run", path_params=path_params or {})
def _app(*, status="normal", enable_api=True):
@ -28,7 +26,9 @@ def test_resolver_rejects_missing_path_param():
AppResolver()(_ctx({}))
def test_resolver_rejects_none_view_args():
def test_resolver_rejects_empty_path_params():
# `Pipeline.guard` always seeds an empty dict when Flask reports no
# view args, so a missing `app_id` key surfaces here as BadRequest.
with pytest.raises(BadRequest):
AppResolver()(_ctx(None))

View File

@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden
@ -11,7 +11,7 @@ from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id="acc1"):
c = Context(request=MagicMock(), required_scope="apps:run")
c = Context(required_scope="apps:run")
c.subject_type = subject_type
c.subject_email = "alice@example.com"
c.account_id = account_id

View File

@ -1,26 +1,31 @@
import uuid
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from flask import Flask, g
from flask import Flask
from werkzeug.exceptions import Unauthorized
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import BearerCheck
from libs.oauth_bearer import AuthContext, InvalidBearerError, Scope, SubjectType
from libs.oauth_bearer import (
AuthContext,
InvalidBearerError,
Scope,
SubjectType,
reset_auth_ctx,
try_get_auth_ctx,
)
def _ctx(headers):
req = MagicMock()
req.headers = headers
return Context(request=req, required_scope="apps:run")
def _ctx(bearer_token: str | None) -> Context:
return Context(required_scope="apps:run", bearer_token=bearer_token)
def test_bearer_check_rejects_missing_header():
app = Flask(__name__)
with app.test_request_context(), pytest.raises(Unauthorized):
BearerCheck()(_ctx({}))
BearerCheck()(_ctx(None))
@patch("controllers.openapi.auth.steps.get_authenticator")
@ -28,11 +33,11 @@ def test_bearer_check_rejects_unknown_prefix(get_auth):
get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
app = Flask(__name__)
with app.test_request_context(), pytest.raises(Unauthorized):
BearerCheck()(_ctx({"Authorization": "Bearer xxx_abc"}))
BearerCheck()(_ctx("xxx_abc"))
@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_populates_context_and_g_auth_ctx(get_auth):
def test_bearer_check_populates_context_and_publishes_auth_ctx(get_auth):
tok_id = uuid.uuid4()
authn = AuthContext(
subject_type=SubjectType.ACCOUNT,
@ -50,18 +55,29 @@ def test_bearer_check_populates_context_and_g_auth_ctx(get_auth):
get_auth.return_value.authenticate.return_value = authn
app = Flask(__name__)
ctx = _ctx({"Authorization": "Bearer dfoa_abc"})
ctx = _ctx("dfoa_abc")
with app.test_request_context():
BearerCheck()(ctx)
assert ctx.subject_type == SubjectType.ACCOUNT
assert ctx.subject_email == "a@x.com"
assert ctx.scopes == frozenset({Scope.FULL})
assert ctx.source == "oauth-account"
assert ctx.token_id == tok_id
assert ctx.token_hash == "hash-1"
# BearerCheck must also publish the same identity on `g.auth_ctx`
# so the surface gate + downstream handlers don't see two
# different identity sources between the decorator + pipeline paths.
assert g.auth_ctx is authn
assert g.auth_ctx.client_id == "difyctl"
try:
assert ctx.subject_type == SubjectType.ACCOUNT
assert ctx.subject_email == "a@x.com"
assert ctx.scopes == frozenset({Scope.FULL})
assert ctx.source == "oauth-account"
assert ctx.token_id == tok_id
assert ctx.token_hash == "hash-1"
# BearerCheck must also publish the same identity on the
# openapi auth ContextVar so the surface gate + downstream
# handlers don't see two different identity sources between
# the decorator + pipeline paths. The reset token is parked
# on `ctx.auth_ctx_reset_token` for `Pipeline.guard` to
# consume in its `finally`.
published = try_get_auth_ctx()
assert published is authn
assert published.client_id == "difyctl"
assert ctx.auth_ctx_reset_token is not None
finally:
# In production `Pipeline.guard` resets the ContextVar; in
# this isolated step-level test we reset it ourselves so the
# value doesn't leak into the next test on the same worker.
assert ctx.auth_ctx_reset_token is not None
reset_auth_ctx(ctx.auth_ctx_reset_token)

View File

@ -15,7 +15,7 @@ from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
c = Context(request=MagicMock(), required_scope="apps:read")
c = Context(required_scope="apps:read")
c.subject_type = subject_type
c.account_id = account_id
c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None

View File

@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Unauthorized
@ -12,7 +12,7 @@ from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id=None, subject_email=None):
c = Context(request=MagicMock(), required_scope="apps:run")
c = Context(required_scope="apps:run")
c.subject_type = subject_type
c.account_id = account_id
c.subject_email = subject_email

View File

@ -1,5 +1,3 @@
from unittest.mock import MagicMock
import pytest
from werkzeug.exceptions import Forbidden
@ -8,7 +6,7 @@ from controllers.openapi.auth.steps import ScopeCheck
def _ctx(scopes, required):
c = Context(request=MagicMock(), required_scope=required)
c = Context(required_scope=required)
c.scopes = frozenset(scopes)
return c

View File

@ -5,24 +5,40 @@ pipeline step (`SurfaceCheck`) — and both must:
- 403 on mismatched subject type with a canonical-path hint
- emit `openapi.wrong_surface_denied` once with the right payload
- pass-through on match
- raise RuntimeError (not 403) if g.auth_ctx is missing — that's a
wiring bug, not a user-driven failure
- raise RuntimeError (not 403) if the auth ContextVar is unset — that's
a wiring bug, not a user-driven failure
Identity is published via `libs.oauth_bearer.set_auth_ctx` / read with
`try_get_auth_ctx`. Tests wrap the publish in a `_publish_auth_ctx`
context manager so the ContextVar resets even when an assertion fails;
that keeps state from leaking into the next test on the same worker.
"""
from __future__ import annotations
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from flask import Flask, g
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import SurfaceCheck
from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface
from libs.oauth_bearer import AuthContext, Scope, SubjectType
from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx
@contextmanager
def _publish_auth_ctx(ctx: AuthContext) -> Iterator[None]:
token = set_auth_ctx(ctx)
try:
yield
finally:
reset_auth_ctx(token)
def _account_ctx() -> AuthContext:
@ -64,15 +80,13 @@ def _sso_ctx() -> AuthContext:
def test_check_surface_passes_when_subject_in_accepted():
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps"):
g.auth_ctx = _account_ctx()
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_account_ctx()):
check_surface(frozenset({SubjectType.ACCOUNT})) # no raise
def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
app = Flask(__name__)
with app.test_request_context("/openapi/v1/permitted-external-apps"):
g.auth_ctx = _account_ctx()
with app.test_request_context("/openapi/v1/permitted-external-apps"), _publish_auth_ctx(_account_ctx()):
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
with pytest.raises(Forbidden) as exc:
check_surface(frozenset({SubjectType.EXTERNAL_SSO}))
@ -90,8 +104,7 @@ def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
def test_check_surface_rejects_sso_on_account_surface():
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps"):
g.auth_ctx = _sso_ctx()
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_sso_ctx()):
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
with pytest.raises(Forbidden):
check_surface(frozenset({SubjectType.ACCOUNT}))
@ -99,11 +112,12 @@ def test_check_surface_rejects_sso_on_account_surface():
assert kwargs["subject_type"] == SubjectType.EXTERNAL_SSO.value
def test_check_surface_runtime_error_when_g_auth_ctx_missing():
"""Missing g.auth_ctx means the bearer layer didn't run — wiring bug,
not a user-driven failure. Surface as RuntimeError (loud) so a future
refactor doesn't accidentally let a route skip authentication and
return a 403 that looks identical to a legitimate wrong-surface deny.
def test_check_surface_runtime_error_when_auth_ctx_missing():
"""Missing auth ContextVar means the bearer layer didn't run — wiring
bug, not a user-driven failure. Surface as RuntimeError (loud) so a
future refactor doesn't accidentally let a route skip authentication
and return a 403 that looks identical to a legitimate wrong-surface
deny.
"""
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps"):
@ -134,8 +148,7 @@ def _make_app() -> Flask:
def test_accept_subjects_decorator_passes_on_match():
app = _make_app()
with app.test_request_context("/account-only"):
g.auth_ctx = _account_ctx()
with app.test_request_context("/account-only"), _publish_auth_ctx(_account_ctx()):
# Re-route through the decorated function by reaching for view_function
view = app.view_functions["_account_only"]
assert view() == "ok"
@ -143,8 +156,7 @@ def test_accept_subjects_decorator_passes_on_match():
def test_accept_subjects_decorator_403_on_miss():
app = _make_app()
with app.test_request_context("/external-only"):
g.auth_ctx = _account_ctx()
with app.test_request_context("/external-only"), _publish_auth_ctx(_account_ctx()):
view = app.view_functions["_external_only"]
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"):
with pytest.raises(Forbidden):
@ -157,24 +169,22 @@ def test_accept_subjects_decorator_403_on_miss():
def _pipeline_ctx() -> Context:
req = MagicMock()
req.path = "/openapi/v1/apps/<id>/run"
return Context(request=req, required_scope=Scope.APPS_RUN)
# SurfaceCheck reads ``request.path`` from Flask's global request — set up
# via ``app.test_request_context`` in the calling tests — not from Context.
return Context(required_scope=Scope.APPS_RUN)
def test_surface_check_passes_on_match():
step = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT}))
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps/x/run"):
g.auth_ctx = _account_ctx()
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
step(_pipeline_ctx()) # no raise
def test_surface_check_rejects_on_miss_and_emits_audit():
step = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO}))
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps/x/run"):
g.auth_ctx = _account_ctx()
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
with pytest.raises(Forbidden):
step(_pipeline_ctx())