Files
dify/api/tests/unit_tests/services/test_oauth_device_flow.py
2026-05-23 14:21:04 +08:00

147 lines
4.7 KiB
Python

from __future__ import annotations
import uuid
from unittest.mock import MagicMock
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType
from services.oauth_device_flow import (
list_active_sessions,
revoke_oauth_token,
subject_match_clauses,
token_belongs_to_subject,
)
def _account_ctx() -> AuthContext:
    """Build an ``AuthContext`` fixture for a Dify account bearer.

    The subject is ``SubjectType.ACCOUNT`` with a concrete ``account_id``,
    the ``dify:account`` issuer and source ``oauth_account``. Both the
    account id and the token id are fresh random UUIDs on every call, so
    two fixtures never share identifiers.
    """
    account_uuid = uuid.uuid4()
    token_uuid = uuid.uuid4()
    return AuthContext(
        subject_type=SubjectType.ACCOUNT,
        subject_email="user@example.com",
        subject_issuer="dify:account",
        account_id=account_uuid,
        token_id=token_uuid,
        client_id="difyctl",
        scopes=frozenset({"full"}),
        source="oauth_account",
        token_hash="h1",
        expires_at=None,
        verified_tenants={},
    )
def _sso_ctx() -> AuthContext:
    """Build an ``AuthContext`` fixture for an external-SSO bearer.

    Unlike the account fixture, ``account_id`` is ``None``: the subject is
    described only by its email plus the external IdP issuer URL, with the
    narrower ``apps:run`` scope and source ``oauth_external_sso``. The token
    id is a fresh random UUID on every call.
    """
    ctx_kwargs = {
        "subject_type": SubjectType.EXTERNAL_SSO,
        "subject_email": "sso@partner.com",
        "subject_issuer": "https://idp.partner.com",
        "account_id": None,
        "client_id": "difyctl",
        "scopes": frozenset({"apps:run"}),
        "token_id": uuid.uuid4(),
        "source": "oauth_external_sso",
        "expires_at": None,
        "token_hash": "h1",
        "verified_tenants": {},
    }
    return AuthContext(**ctx_kwargs)
# ---------------------------------------------------------------------------
# subject_match_clauses
# ---------------------------------------------------------------------------
def test_subject_match_clauses_account_matches_only_account_id():
    """Account bearers are scoped by a single ``account_id`` clause.

    A bare ``"account_id" in ...`` check is too weak: it would also pass for
    the ``account_id IS NULL`` guard used on the external-SSO path. The
    clause is therefore also checked to be neither that null guard nor an
    email/issuer match, which would widen the set of rows the bearer can
    reach.
    """
    clauses = subject_match_clauses(_account_ctx())
    assert len(clauses) == 1

    rendered = str(clauses[0])
    assert "account_id" in rendered
    # The account fixture has a concrete account_id, so the clause must not
    # be the SSO-style null guard.
    assert "IS NULL" not in rendered
    # "only account_id": the email/issuer pair must not be part of the match.
    assert "subject_email" not in rendered
    assert "subject_issuer" not in rendered
def test_subject_match_clauses_external_sso_requires_null_account_id():
    """SSO bearers match on email + issuer AND ``account_id IS NULL``.

    The extra ``account_id IS NULL`` guard stops a same-email account-flow
    row from a federated tenant from being enumerated or revoked through an
    SSO bearer.
    """
    sso_clauses = subject_match_clauses(_sso_ctx())
    assert len(sso_clauses) == 3

    joined_sql = " ".join(map(str, sso_clauses))
    for fragment in ("subject_email", "subject_issuer", "account_id IS NULL"):
        assert fragment in joined_sql
# ---------------------------------------------------------------------------
# revoke_oauth_token
# ---------------------------------------------------------------------------
def test_revoke_oauth_token_invalidates_redis_cache_when_live_hash_seen():
    """Live token: the row is revoked and its Redis cache entry is dropped.

    The pre-revoke snapshot returns a live ``token_hash``. The service must
    then run the UPDATE, commit, and DEL the Redis entry keyed by that hash,
    so the next bearer probe re-reads the now-revoked row from the DB.
    """
    db_session = MagicMock()
    snapshot_query = db_session.query.return_value.filter.return_value
    snapshot_query.one_or_none.return_value = ("live-hash",)
    redis_client = MagicMock()

    revoke_oauth_token(db_session, redis_client, "token-id")

    assert db_session.execute.called  # UPDATE ... WHERE revoked_at IS NULL
    assert db_session.commit.called
    expected_cache_key = TOKEN_CACHE_KEY_FMT.format(hash="live-hash")
    redis_client.delete.assert_called_once_with(expected_cache_key)
def test_revoke_oauth_token_is_idempotent_when_already_revoked():
    """Already-revoked token: the UPDATE still runs, Redis is left alone.

    This covers a second call or the loser of a race. The snapshot finds no
    live hash, so there is no cache entry to drop and ``delete`` must not be
    called. The UPDATE itself still executes and commits; per its
    ``WHERE revoked_at IS NULL`` guard it is a no-op on a revoked row.
    """
    db_session = MagicMock()
    snapshot_query = db_session.query.return_value.filter.return_value
    snapshot_query.one_or_none.return_value = None
    redis_client = MagicMock()

    revoke_oauth_token(db_session, redis_client, "token-id")

    assert db_session.execute.called
    assert db_session.commit.called
    redis_client.delete.assert_not_called()
# ---------------------------------------------------------------------------
# list_active_sessions / token_belongs_to_subject
# ---------------------------------------------------------------------------
def test_list_active_sessions_returns_session_execute_rows():
    """``list_active_sessions`` hands back ``execute(...).scalars().all()``.

    The helper is a thin delegation. ``.scalars()`` unwraps each one-element
    ``Row``, so callers get bare ``OAuthAccessToken`` entities, matching the
    declared return type. The mocked result list must come back unchanged.
    """
    from datetime import UTC, datetime

    db_session = MagicMock()
    expected_rows = [MagicMock(), MagicMock()]
    scalar_result = db_session.execute.return_value.scalars.return_value
    scalar_result.all.return_value = expected_rows

    now = datetime.now(UTC)
    sessions = list_active_sessions(db_session, _account_ctx(), now)

    assert sessions == expected_rows
    assert db_session.execute.called
def test_token_belongs_to_subject_true_when_row_present():
    """A row from ``execute(...).first()`` means the token is the subject's."""
    db_session = MagicMock()
    db_session.execute.return_value.first.return_value = ("some-id",)
    owner_ctx = _account_ctx()

    belongs = token_belongs_to_subject(db_session, "token-id", owner_ctx)

    assert belongs is True
def test_token_belongs_to_subject_false_when_no_row():
    """No row from ``execute(...).first()`` means the token is not the subject's."""
    db_session = MagicMock()
    db_session.execute.return_value.first.return_value = None
    owner_ctx = _account_ctx()

    belongs = token_belongs_to_subject(db_session, "token-id", owner_ctx)

    assert belongs is False