mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
test(api): isolate login decorator test patches
This commit is contained in:
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import importlib
|
||||
from contextlib import contextmanager
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -18,7 +18,6 @@ if not hasattr(builtins, "MethodView"):
|
||||
|
||||
_CONTROLLER_MODULE: ModuleType | None = None
|
||||
_WRAPS_MODULE: ModuleType | None = None
|
||||
_CONTROLLER_PATCHERS: list[patch] = []
|
||||
|
||||
|
||||
@contextmanager
|
||||
@ -37,6 +36,14 @@ def app() -> Flask:
|
||||
|
||||
@pytest.fixture
|
||||
def controller_module(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Import the controller with auth decorators neutralized only during import.
|
||||
|
||||
The imported view classes retain those no-op decorators after import, so we
|
||||
can restore the original globals immediately and avoid leaking auth patches
|
||||
into unrelated tests such as libs.login unit coverage.
|
||||
"""
|
||||
|
||||
module_name = "controllers.console.workspace.tool_providers"
|
||||
global _CONTROLLER_MODULE
|
||||
if _CONTROLLER_MODULE is None:
|
||||
@ -51,13 +58,12 @@ def controller_module(monkeypatch: pytest.MonkeyPatch):
|
||||
("controllers.console.wraps.is_admin_or_owner_required", _noop),
|
||||
("controllers.console.wraps.enterprise_license_required", _noop),
|
||||
]
|
||||
for target, value in patch_targets:
|
||||
patcher = patch(target, value)
|
||||
patcher.start()
|
||||
_CONTROLLER_PATCHERS.append(patcher)
|
||||
monkeypatch.setenv("DIFY_SETUP_READY", "true")
|
||||
with _mock_db():
|
||||
_CONTROLLER_MODULE = importlib.import_module(module_name)
|
||||
with ExitStack() as stack:
|
||||
for target, value in patch_targets:
|
||||
stack.enter_context(patch(target, value))
|
||||
with _mock_db():
|
||||
_CONTROLLER_MODULE = importlib.import_module(module_name)
|
||||
|
||||
module = _CONTROLLER_MODULE
|
||||
monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload)
|
||||
|
||||
@ -11,6 +11,17 @@ from libs.login import current_user
|
||||
from models.account import Account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def protected_view():
|
||||
"""Build a small login-protected view that exercises the decorator logic."""
|
||||
|
||||
@login_module.login_required
|
||||
def _protected_view():
|
||||
return "Protected content"
|
||||
|
||||
return _protected_view
|
||||
|
||||
|
||||
class MockUser(UserMixin):
|
||||
"""Mock user class for testing."""
|
||||
|
||||
@ -24,13 +35,13 @@ class MockUser(UserMixin):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def login_app() -> Flask:
|
||||
def login_app(mocker: MockerFixture) -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
login_manager.unauthorized = MagicMock(return_value="Unauthorized")
|
||||
login_manager.unauthorized = mocker.Mock(name="unauthorized", return_value="Unauthorized")
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(_user_id: str):
|
||||
@ -49,20 +60,28 @@ def csrf_check(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch.object(login_module, "check_csrf_token")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resolve_current_user(mocker: MockerFixture):
|
||||
def _patch(user: MockUser | Account | None) -> MagicMock:
|
||||
return mocker.patch.object(login_module, "_resolve_current_user", return_value=user)
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
class TestLoginRequired:
|
||||
"""Test cases for login_required decorator."""
|
||||
|
||||
def test_authenticated_user_can_access_protected_view(
|
||||
self, login_app: Flask, csrf_check: MagicMock, mocker: MockerFixture
|
||||
self,
|
||||
login_app: Flask,
|
||||
protected_view,
|
||||
csrf_check: MagicMock,
|
||||
resolve_current_user,
|
||||
):
|
||||
"""Test that authenticated users can access protected views."""
|
||||
|
||||
@login_module.login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
resolve_user = mocker.patch.object(login_module, "_resolve_current_user", return_value=mock_user)
|
||||
resolve_user = resolve_current_user(mock_user)
|
||||
|
||||
with login_app.test_request_context():
|
||||
result = protected_view()
|
||||
@ -84,18 +103,15 @@ class TestLoginRequired:
|
||||
def test_unauthorized_access_returns_login_manager_response(
|
||||
self,
|
||||
login_app: Flask,
|
||||
protected_view,
|
||||
csrf_check: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
resolve_current_user,
|
||||
resolved_user: MockUser | None,
|
||||
description: str,
|
||||
):
|
||||
"""Test that missing or unauthenticated users are redirected."""
|
||||
|
||||
@login_module.login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
resolve_user = mocker.patch.object(login_module, "_resolve_current_user", return_value=resolved_user)
|
||||
resolve_user = resolve_current_user(resolved_user)
|
||||
|
||||
with login_app.test_request_context():
|
||||
result = protected_view()
|
||||
@ -115,19 +131,16 @@ class TestLoginRequired:
|
||||
def test_bypass_paths_skip_authentication_and_csrf(
|
||||
self,
|
||||
login_app: Flask,
|
||||
protected_view,
|
||||
csrf_check: MagicMock,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mocker: MockerFixture,
|
||||
resolve_current_user,
|
||||
method: str,
|
||||
login_disabled: bool,
|
||||
):
|
||||
"""Test that bypass conditions skip auth lookup, CSRF, and unauthorized handling."""
|
||||
|
||||
@login_module.login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
resolve_user = mocker.patch.object(login_module, "_resolve_current_user")
|
||||
resolve_user = resolve_current_user(MockUser("test_user"))
|
||||
monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled)
|
||||
|
||||
with login_app.test_request_context(method=method):
|
||||
@ -151,20 +164,20 @@ class TestGetUser:
|
||||
assert user == mock_user
|
||||
assert user.id == "test_user"
|
||||
|
||||
def test_get_user_loads_user_if_not_in_g(self, login_app: Flask):
|
||||
def test_get_user_loads_user_if_not_in_g(self, login_app: Flask, mocker: MockerFixture):
|
||||
"""Test that _get_user loads user if not already in g."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
def _load_user() -> None:
|
||||
g._login_user = mock_user
|
||||
|
||||
login_app.login_manager._load_user = MagicMock(side_effect=_load_user)
|
||||
load_user = mocker.patch.object(login_app.login_manager, "_load_user", side_effect=_load_user)
|
||||
|
||||
with login_app.test_request_context():
|
||||
user = login_module._get_user()
|
||||
|
||||
assert user == mock_user
|
||||
login_app.login_manager._load_user.assert_called_once_with()
|
||||
load_user.assert_called_once_with()
|
||||
|
||||
def test_get_user_returns_none_without_request_context(self):
|
||||
"""Test that _get_user returns None outside request context."""
|
||||
@ -199,7 +212,7 @@ class TestCurrentAccountWithTenant:
|
||||
def test_returns_account_and_tenant_id(self, mocker: MockerFixture):
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account._current_tenant = SimpleNamespace(id="tenant-123")
|
||||
current_user_proxy = MagicMock()
|
||||
current_user_proxy = mocker.Mock()
|
||||
current_user_proxy._get_current_object.return_value = account
|
||||
mocker.patch.object(login_module, "current_user", new=current_user_proxy)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user