mirror of
https://github.com/langgenius/dify.git
synced 2026-06-08 09:27:39 +08:00
refactor: decouple Context from flask
This commit is contained in:
@ -1,10 +1,10 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
|
||||
|
||||
def test_context_starts_unpopulated():
|
||||
ctx = Context(request=MagicMock(), required_scope="apps:run")
|
||||
ctx = Context(required_scope="apps:run")
|
||||
assert ctx.bearer_token is None
|
||||
assert ctx.path_params == {}
|
||||
assert ctx.subject_type is None
|
||||
assert ctx.subject_email is None
|
||||
assert ctx.account_id is None
|
||||
@ -16,6 +16,6 @@ def test_context_starts_unpopulated():
|
||||
|
||||
|
||||
def test_context_fields_are_mutable():
|
||||
ctx = Context(request=MagicMock(), required_scope="apps:run")
|
||||
ctx = Context(required_scope="apps:run")
|
||||
ctx.scopes = frozenset({"full"})
|
||||
assert "full" in ctx.scopes
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
@ -17,7 +15,7 @@ def test_run_invokes_each_step_in_order():
|
||||
def __call__(self, ctx):
|
||||
calls.append(self.tag)
|
||||
|
||||
Pipeline(S("a"), S("b"), S("c")).run(Context(request=MagicMock(), required_scope="x"))
|
||||
Pipeline(S("a"), S("b"), S("c")).run(Context(required_scope="x"))
|
||||
assert calls == ["a", "b", "c"]
|
||||
|
||||
|
||||
@ -33,7 +31,7 @@ def test_run_short_circuits_on_raise():
|
||||
calls.append("ran")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
Pipeline(Boom(), Tail()).run(Context(request=MagicMock(), required_scope="x"))
|
||||
Pipeline(Boom(), Tail()).run(Context(required_scope="x"))
|
||||
assert calls == []
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
@ -9,10 +9,8 @@ from controllers.openapi.auth.steps import AppResolver
|
||||
from models import TenantStatus
|
||||
|
||||
|
||||
def _ctx(view_args):
|
||||
req = MagicMock()
|
||||
req.view_args = view_args
|
||||
return Context(request=req, required_scope="apps:run")
|
||||
def _ctx(path_params: dict[str, str] | None) -> Context:
|
||||
return Context(required_scope="apps:run", path_params=path_params or {})
|
||||
|
||||
|
||||
def _app(*, status="normal", enable_api=True):
|
||||
@ -28,7 +26,9 @@ def test_resolver_rejects_missing_path_param():
|
||||
AppResolver()(_ctx({}))
|
||||
|
||||
|
||||
def test_resolver_rejects_none_view_args():
|
||||
def test_resolver_rejects_empty_path_params():
|
||||
# `Pipeline.guard` always seeds an empty dict when Flask reports no
|
||||
# view args, so a missing `app_id` key surfaces here as BadRequest.
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx(None))
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@ -11,7 +11,7 @@ from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id="acc1"):
|
||||
c = Context(request=MagicMock(), required_scope="apps:run")
|
||||
c = Context(required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.subject_email = "alice@example.com"
|
||||
c.account_id = account_id
|
||||
|
||||
@ -1,26 +1,31 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import BearerCheck
|
||||
from libs.oauth_bearer import AuthContext, InvalidBearerError, Scope, SubjectType
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
reset_auth_ctx,
|
||||
try_get_auth_ctx,
|
||||
)
|
||||
|
||||
|
||||
def _ctx(headers):
|
||||
req = MagicMock()
|
||||
req.headers = headers
|
||||
return Context(request=req, required_scope="apps:run")
|
||||
def _ctx(bearer_token: str | None) -> Context:
|
||||
return Context(required_scope="apps:run", bearer_token=bearer_token)
|
||||
|
||||
|
||||
def test_bearer_check_rejects_missing_header():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx({}))
|
||||
BearerCheck()(_ctx(None))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
@ -28,11 +33,11 @@ def test_bearer_check_rejects_unknown_prefix(get_auth):
|
||||
get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx({"Authorization": "Bearer xxx_abc"}))
|
||||
BearerCheck()(_ctx("xxx_abc"))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_populates_context_and_g_auth_ctx(get_auth):
|
||||
def test_bearer_check_populates_context_and_publishes_auth_ctx(get_auth):
|
||||
tok_id = uuid.uuid4()
|
||||
authn = AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
@ -50,18 +55,29 @@ def test_bearer_check_populates_context_and_g_auth_ctx(get_auth):
|
||||
get_auth.return_value.authenticate.return_value = authn
|
||||
|
||||
app = Flask(__name__)
|
||||
ctx = _ctx({"Authorization": "Bearer dfoa_abc"})
|
||||
ctx = _ctx("dfoa_abc")
|
||||
with app.test_request_context():
|
||||
BearerCheck()(ctx)
|
||||
|
||||
assert ctx.subject_type == SubjectType.ACCOUNT
|
||||
assert ctx.subject_email == "a@x.com"
|
||||
assert ctx.scopes == frozenset({Scope.FULL})
|
||||
assert ctx.source == "oauth-account"
|
||||
assert ctx.token_id == tok_id
|
||||
assert ctx.token_hash == "hash-1"
|
||||
# BearerCheck must also publish the same identity on `g.auth_ctx`
|
||||
# so the surface gate + downstream handlers don't see two
|
||||
# different identity sources between the decorator + pipeline paths.
|
||||
assert g.auth_ctx is authn
|
||||
assert g.auth_ctx.client_id == "difyctl"
|
||||
try:
|
||||
assert ctx.subject_type == SubjectType.ACCOUNT
|
||||
assert ctx.subject_email == "a@x.com"
|
||||
assert ctx.scopes == frozenset({Scope.FULL})
|
||||
assert ctx.source == "oauth-account"
|
||||
assert ctx.token_id == tok_id
|
||||
assert ctx.token_hash == "hash-1"
|
||||
# BearerCheck must also publish the same identity on the
|
||||
# openapi auth ContextVar so the surface gate + downstream
|
||||
# handlers don't see two different identity sources between
|
||||
# the decorator + pipeline paths. The reset token is parked
|
||||
# on `ctx.auth_ctx_reset_token` for `Pipeline.guard` to
|
||||
# consume in its `finally`.
|
||||
published = try_get_auth_ctx()
|
||||
assert published is authn
|
||||
assert published.client_id == "difyctl"
|
||||
assert ctx.auth_ctx_reset_token is not None
|
||||
finally:
|
||||
# In production `Pipeline.guard` resets the ContextVar; in
|
||||
# this isolated step-level test we reset it ourselves so the
|
||||
# value doesn't leak into the next test on the same worker.
|
||||
assert ctx.auth_ctx_reset_token is not None
|
||||
reset_auth_ctx(ctx.auth_ctx_reset_token)
|
||||
|
||||
@ -15,7 +15,7 @@ from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
|
||||
c = Context(request=MagicMock(), required_scope="apps:read")
|
||||
c = Context(required_scope="apps:read")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
@ -12,7 +12,7 @@ from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id=None, subject_email=None):
|
||||
c = Context(request=MagicMock(), required_scope="apps:run")
|
||||
c = Context(required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.subject_email = subject_email
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@ -8,7 +6,7 @@ from controllers.openapi.auth.steps import ScopeCheck
|
||||
|
||||
|
||||
def _ctx(scopes, required):
|
||||
c = Context(request=MagicMock(), required_scope=required)
|
||||
c = Context(required_scope=required)
|
||||
c.scopes = frozenset(scopes)
|
||||
return c
|
||||
|
||||
|
||||
@ -5,24 +5,40 @@ pipeline step (`SurfaceCheck`) — and both must:
|
||||
- 403 on mismatched subject type with a canonical-path hint
|
||||
- emit `openapi.wrong_surface_denied` once with the right payload
|
||||
- pass-through on match
|
||||
- raise RuntimeError (not 403) if g.auth_ctx is missing — that's a
|
||||
wiring bug, not a user-driven failure
|
||||
- raise RuntimeError (not 403) if the auth ContextVar is unset — that's
|
||||
a wiring bug, not a user-driven failure
|
||||
|
||||
Identity is published via `libs.oauth_bearer.set_auth_ctx` / read with
|
||||
`try_get_auth_ctx`. Tests wrap the publish in a `_publish_auth_ctx`
|
||||
context manager so the ContextVar resets even when an assertion fails;
|
||||
that keeps state from leaking into the next test on the same worker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import SurfaceCheck
|
||||
from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _publish_auth_ctx(ctx: AuthContext) -> Iterator[None]:
|
||||
token = set_auth_ctx(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
reset_auth_ctx(token)
|
||||
|
||||
|
||||
def _account_ctx() -> AuthContext:
|
||||
@ -64,15 +80,13 @@ def _sso_ctx() -> AuthContext:
|
||||
|
||||
def test_check_surface_passes_when_subject_in_accepted():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_account_ctx()):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT})) # no raise
|
||||
|
||||
|
||||
def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/permitted-external-apps"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
with app.test_request_context("/openapi/v1/permitted-external-apps"), _publish_auth_ctx(_account_ctx()):
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
check_surface(frozenset({SubjectType.EXTERNAL_SSO}))
|
||||
@ -90,8 +104,7 @@ def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
|
||||
|
||||
def test_check_surface_rejects_sso_on_account_surface():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
g.auth_ctx = _sso_ctx()
|
||||
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_sso_ctx()):
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT}))
|
||||
@ -99,11 +112,12 @@ def test_check_surface_rejects_sso_on_account_surface():
|
||||
assert kwargs["subject_type"] == SubjectType.EXTERNAL_SSO.value
|
||||
|
||||
|
||||
def test_check_surface_runtime_error_when_g_auth_ctx_missing():
|
||||
"""Missing g.auth_ctx means the bearer layer didn't run — wiring bug,
|
||||
not a user-driven failure. Surface as RuntimeError (loud) so a future
|
||||
refactor doesn't accidentally let a route skip authentication and
|
||||
return a 403 that looks identical to a legitimate wrong-surface deny.
|
||||
def test_check_surface_runtime_error_when_auth_ctx_missing():
|
||||
"""Missing auth ContextVar means the bearer layer didn't run — wiring
|
||||
bug, not a user-driven failure. Surface as RuntimeError (loud) so a
|
||||
future refactor doesn't accidentally let a route skip authentication
|
||||
and return a 403 that looks identical to a legitimate wrong-surface
|
||||
deny.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
@ -134,8 +148,7 @@ def _make_app() -> Flask:
|
||||
|
||||
def test_accept_subjects_decorator_passes_on_match():
|
||||
app = _make_app()
|
||||
with app.test_request_context("/account-only"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
with app.test_request_context("/account-only"), _publish_auth_ctx(_account_ctx()):
|
||||
# Re-route through the decorated function by reaching for view_function
|
||||
view = app.view_functions["_account_only"]
|
||||
assert view() == "ok"
|
||||
@ -143,8 +156,7 @@ def test_accept_subjects_decorator_passes_on_match():
|
||||
|
||||
def test_accept_subjects_decorator_403_on_miss():
|
||||
app = _make_app()
|
||||
with app.test_request_context("/external-only"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
with app.test_request_context("/external-only"), _publish_auth_ctx(_account_ctx()):
|
||||
view = app.view_functions["_external_only"]
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"):
|
||||
with pytest.raises(Forbidden):
|
||||
@ -157,24 +169,22 @@ def test_accept_subjects_decorator_403_on_miss():
|
||||
|
||||
|
||||
def _pipeline_ctx() -> Context:
|
||||
req = MagicMock()
|
||||
req.path = "/openapi/v1/apps/<id>/run"
|
||||
return Context(request=req, required_scope=Scope.APPS_RUN)
|
||||
# SurfaceCheck reads ``request.path`` from Flask's global request — set up
|
||||
# via ``app.test_request_context`` in the calling tests — not from Context.
|
||||
return Context(required_scope=Scope.APPS_RUN)
|
||||
|
||||
|
||||
def test_surface_check_passes_on_match():
|
||||
step = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT}))
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
|
||||
step(_pipeline_ctx()) # no raise
|
||||
|
||||
|
||||
def test_surface_check_rejects_on_miss_and_emits_audit():
|
||||
step = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO}))
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden):
|
||||
step(_pipeline_ctx())
|
||||
|
||||
Reference in New Issue
Block a user