mirror of
https://github.com/langgenius/dify.git
synced 2026-05-26 11:57:40 +08:00
147 lines
4.7 KiB
Python
147 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
from unittest.mock import MagicMock
|
|
|
|
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType
|
|
from services.oauth_device_flow import (
|
|
list_active_sessions,
|
|
revoke_oauth_token,
|
|
subject_match_clauses,
|
|
token_belongs_to_subject,
|
|
)
|
|
|
|
|
|
def _account_ctx() -> AuthContext:
    """Build a fresh bearer context for a first-party Dify account subject.

    The subject carries a random ``account_id`` and the ``full`` scope, with
    ``source="oauth_account"``; a new ``token_id`` is generated on every call.
    """
    subject_fields = {
        "subject_type": SubjectType.ACCOUNT,
        "subject_email": "user@example.com",
        "subject_issuer": "dify:account",
        "account_id": uuid.uuid4(),
    }
    token_fields = {
        "client_id": "difyctl",
        "scopes": frozenset({"full"}),
        "token_id": uuid.uuid4(),
        "source": "oauth_account",
        "expires_at": None,
        "token_hash": "h1",
        "verified_tenants": {},
    }
    return AuthContext(**subject_fields, **token_fields)
|
|
|
|
|
|
def _sso_ctx() -> AuthContext:
    """Build a fresh bearer context for an external-SSO subject.

    Unlike the account context there is no ``account_id``: the subject is
    described only by its email and IdP issuer, with the ``apps:run`` scope
    and ``source="oauth_external_sso"``. A new ``token_id`` is generated on
    every call.
    """
    subject_fields = {
        "subject_type": SubjectType.EXTERNAL_SSO,
        "subject_email": "sso@partner.com",
        "subject_issuer": "https://idp.partner.com",
        "account_id": None,
    }
    token_fields = {
        "client_id": "difyctl",
        "scopes": frozenset({"apps:run"}),
        "token_id": uuid.uuid4(),
        "source": "oauth_external_sso",
        "expires_at": None,
        "token_hash": "h1",
        "verified_tenants": {},
    }
    return AuthContext(**subject_fields, **token_fields)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# subject_match_clauses
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_subject_match_clauses_account_matches_only_account_id():
    """Account bearers must be matched on ``account_id`` alone.

    ``len(clauses) == 1`` by itself is not enough to prove "only account_id":
    a single ``and_()`` clause could still bundle the email/issuer columns
    used for external-SSO subjects. So the rendered clause is also checked
    for the absence of those columns.
    """
    clauses = subject_match_clauses(_account_ctx())

    assert len(clauses) == 1
    rendered = str(clauses[0])
    assert "account_id" in rendered
    # The SSO-only identity columns must not leak into account matching.
    assert "subject_email" not in rendered
    assert "subject_issuer" not in rendered
|
|
|
|
|
|
def test_subject_match_clauses_external_sso_requires_null_account_id():
    """An SSO bearer is matched on email + issuer *and* ``account_id IS NULL``.

    The extra NULL guard keeps an account-flow row with the same email (for
    example one from a federated tenant) from being listed or revoked
    through an SSO bearer.
    """
    sso_clauses = subject_match_clauses(_sso_ctx())
    assert len(sso_clauses) == 3

    sql_text = " ".join(map(str, sso_clauses))
    for fragment in ("subject_email", "subject_issuer", "account_id IS NULL"):
        assert fragment in sql_text
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# revoke_oauth_token
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_revoke_oauth_token_invalidates_redis_cache_when_live_hash_seen():
    """Happy path: the pre-revoke snapshot still sees a live ``token_hash``.

    The revoke must run its UPDATE and commit it. It must also DEL the
    matching Redis cache entry, so the next bearer lookup re-reads the
    now-revoked row from the DB.
    """
    db = MagicMock()
    snapshot = db.query.return_value.filter.return_value.one_or_none
    snapshot.return_value = ("live-hash",)
    cache = MagicMock()

    revoke_oauth_token(db, cache, "token-id")

    db.execute.assert_called()  # UPDATE ... WHERE revoked_at IS NULL
    db.commit.assert_called()
    expected_key = TOKEN_CACHE_KEY_FMT.format(hash="live-hash")
    cache.delete.assert_called_once_with(expected_key)
|
|
|
|
|
|
def test_revoke_oauth_token_is_idempotent_when_already_revoked():
    """Repeat call (or race loser): the snapshot finds no live hash.

    The UPDATE still runs; it is idempotent on its own thanks to
    ``WHERE revoked_at IS NULL``. The Redis DEL is skipped because there is
    no cache entry to drop.
    """
    db = MagicMock()
    snapshot = db.query.return_value.filter.return_value.one_or_none
    snapshot.return_value = None
    cache = MagicMock()

    revoke_oauth_token(db, cache, "token-id")

    db.execute.assert_called()
    db.commit.assert_called()
    cache.delete.assert_not_called()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_active_sessions / token_belongs_to_subject
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_list_active_sessions_returns_session_execute_rows():
    """``list_active_sessions`` is a thin pass-through.

    It materialises ``session.execute(...).scalars().all()`` into a list.
    ``.scalars()`` unwraps each one-element ``Row``, so callers get bare
    ``OAuthAccessToken`` entities, matching the declared return type.
    """
    from datetime import UTC, datetime

    db = MagicMock()
    expected_rows = [MagicMock() for _ in range(2)]
    db.execute.return_value.scalars.return_value.all.return_value = expected_rows

    ctx = _account_ctx()
    now = datetime.now(UTC)
    result = list_active_sessions(db, ctx, now)

    assert result == expected_rows
    db.execute.assert_called()
|
|
|
|
|
|
def test_token_belongs_to_subject_true_when_row_present():
    """A row returned by ``session.execute(...).first()`` means the token is owned."""
    db = MagicMock()
    db.execute.return_value.first.return_value = ("some-id",)

    owned = token_belongs_to_subject(db, "token-id", _account_ctx())
    assert owned is True
|
|
|
|
|
|
def test_token_belongs_to_subject_false_when_no_row():
    """No row from ``session.execute(...).first()`` means the token is not owned."""
    db = MagicMock()
    db.execute.return_value.first.return_value = None

    owned = token_belongs_to_subject(db, "token-id", _account_ctx())
    assert owned is False
|