diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index df0ee8d69f..aff364d561 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -325,27 +325,71 @@ def validate_and_get_api_token(scope: str | None = None): _async_update_token_last_used_at(auth_token, scope) return cached_token - # Cache miss - query database - logger.debug("Token cache miss, querying database for scope: %s", scope) - current_time = naive_utc_now() + # Cache miss - use Redis lock for single-flight mode + # This ensures only one request queries DB for the same token concurrently + logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope) + + lock_key = f"api_token_query_lock:{scope}:{auth_token}" + lock = redis_client.lock(lock_key, timeout=10, blocking_timeout=5) + + try: + if lock.acquire(blocking=True): + try: + # Double-check cache after acquiring lock + # (another concurrent request might have already cached it) + cached_token = ApiTokenCache.get(auth_token, scope) + if cached_token is not None: + logger.debug("Token cached by concurrent request, using cached version") + return cached_token + + # Still not cached - query database + with Session(db.engine, expire_on_commit=False) as session: + current_time = naive_utc_now() + update_token_last_used_at(auth_token, scope, current_time, session=session) + + stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) + api_token = session.scalar(stmt) - with Session(db.engine, expire_on_commit=False) as session: - # Use unified update method to avoid code duplication with Celery task - update_token_last_used_at(auth_token, scope, current_time, session=session) - - # Query the token - stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) - api_token = session.scalar(stmt) + if not api_token: + ApiTokenCache.set(auth_token, scope, None) + raise Unauthorized("Access token is invalid") - if not api_token: - # Cache the null result to prevent cache penetration attacks - ApiTokenCache.set(auth_token, scope, None) - raise Unauthorized("Access token is invalid") + ApiTokenCache.set(auth_token, scope, api_token) + return api_token + finally: + lock.release() + else: + # Lock acquisition timeout - fallback to direct query + logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10]) + with Session(db.engine, expire_on_commit=False) as session: + current_time = naive_utc_now() + update_token_last_used_at(auth_token, scope, current_time, session=session) + + stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) + api_token = session.scalar(stmt) - # Cache the valid token - ApiTokenCache.set(auth_token, scope, api_token) + if not api_token: + ApiTokenCache.set(auth_token, scope, None) + raise Unauthorized("Access token is invalid") - return api_token + ApiTokenCache.set(auth_token, scope, api_token) + return api_token + except Exception as e: + # Redis lock failure - fallback to direct query to ensure service availability + logger.warning("Redis lock failed for token query: %s, proceeding anyway", e) + with Session(db.engine, expire_on_commit=False) as session: + current_time = naive_utc_now() + update_token_last_used_at(auth_token, scope, current_time, session=session) + + stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) + api_token = session.scalar(stmt) + + if not api_token: + ApiTokenCache.set(auth_token, scope, None) + raise Unauthorized("Access token is invalid") + + ApiTokenCache.set(auth_token, scope, api_token) + return api_token def _async_update_token_last_used_at(auth_token: str, scope: str | None): diff --git a/api/libs/api_token_cache.py b/api/libs/api_token_cache.py index 86aad97b90..14d264e39f 100644 --- a/api/libs/api_token_cache.py +++ b/api/libs/api_token_cache.py @@ -9,36 +9,33 @@ import logging from datetime import datetime from typing import Any +from pydantic import BaseModel + from extensions.ext_redis import redis_client, redis_fallback logger = logging.getLogger(__name__) -class CachedApiToken: +class CachedApiToken(BaseModel): """ - Simple data class to represent a cached API token. + Pydantic model for cached API token data. - This is NOT a SQLAlchemy model instance, but a plain Python object + This is NOT a SQLAlchemy model instance, but a plain Pydantic model that mimics the ApiToken model interface for read-only access. + + Using Pydantic provides: + - Automatic type validation + - Better IDE support + - Built-in serialization/deserialization """ - def __init__( - self, - id: str, - app_id: str | None, - tenant_id: str | None, - type: str, - token: str, - last_used_at: datetime | None, - created_at: datetime | None, - ): - self.id = id - self.app_id = app_id - self.tenant_id = tenant_id - self.type = type - self.token = token - self.last_used_at = last_used_at - self.created_at = created_at + id: str + app_id: str | None + tenant_id: str | None + type: str + token: str + last_used_at: datetime | None + created_at: datetime | None def __repr__(self) -> str: return f"" @@ -77,11 +74,16 @@ class ApiTokenCache: Serialize ApiToken object to JSON string. Args: - api_token: ApiToken model instance + api_token: ApiToken model instance or CachedApiToken Returns: JSON string representation """ + # If it's already a Pydantic model, use model_dump_json + if isinstance(api_token, CachedApiToken): + return api_token.model_dump_json() + + # Otherwise, convert from SQLAlchemy model data = { "id": str(api_token.id), "app_id": str(api_token.app_id) if api_token.app_id else None, @@ -96,7 +98,7 @@ class ApiTokenCache: @staticmethod def _deserialize_token(cached_data: str) -> Any: """ - Deserialize JSON string back to a CachedApiToken object. + Deserialize JSON string back to a CachedApiToken Pydantic model. Args: cached_data: JSON string from cache @@ -109,22 +111,10 @@ class ApiTokenCache: return None try: - data = json.loads(cached_data) - - # Create a simple data object (NOT a SQLAlchemy model instance) - # This is safe because it's just a plain Python object with attributes - token_obj = CachedApiToken( - id=data["id"], - app_id=data["app_id"], - tenant_id=data["tenant_id"], - type=data["type"], - token=data["token"], - last_used_at=datetime.fromisoformat(data["last_used_at"]) if data["last_used_at"] else None, - created_at=datetime.fromisoformat(data["created_at"]) if data["created_at"] else None, - ) - + # Use Pydantic's model_validate_json for automatic validation + token_obj = CachedApiToken.model_validate_json(cached_data) return token_obj - except (json.JSONDecodeError, KeyError, ValueError) as e: + except (json.JSONDecodeError, ValueError) as e: logger.warning("Failed to deserialize token from cache: %s", e) return None @@ -289,92 +279,3 @@ class ApiTokenCache: except Exception as e: logger.warning("Failed to delete token cache: %s", e) return False - - @staticmethod - @redis_fallback(default_return=False) - def invalidate_by_tenant(tenant_id: str) -> bool: - """ - Invalidate all API token caches for a specific tenant. - Use this when tenant status changes or tokens are batch updated. - - Uses a two-tier approach: - 1. Try to use tenant index (fast, O(n) where n = tenant's tokens) - 2. Fallback to full scan if index doesn't exist (slow, O(N) where N = all tokens) - - Args: - tenant_id: The tenant ID - - Returns: - True if successful, False otherwise - """ - try: - # Try using tenant index first (efficient approach) - index_key = f"tenant_tokens:{tenant_id}" - cache_keys = redis_client.smembers(index_key) - - if cache_keys: - # Index exists - use it (fast path) - deleted_count = 0 - for cache_key in cache_keys: - if isinstance(cache_key, bytes): - cache_key = cache_key.decode("utf-8") - redis_client.delete(cache_key) - deleted_count += 1 - - # Delete the index itself - redis_client.delete(index_key) - - logger.info( - "Invalidated %d token cache entries for tenant: %s (via index)", - deleted_count, - tenant_id, - ) - return True - - # Index doesn't exist - fallback to scanning (slow path) - logger.info("Tenant index not found, falling back to full scan for tenant: %s", tenant_id) - - pattern = f"{CACHE_KEY_PREFIX}:*" - cursor = 0 - deleted_count = 0 - checked_count = 0 - - while True: - cursor, keys = redis_client.scan(cursor, match=pattern, count=100) - if keys: - for key in keys: - checked_count += 1 - try: - # Fetch and check if this token belongs to the tenant - cached_data = redis_client.get(key) - if cached_data: - # Decode if bytes - if isinstance(cached_data, bytes): - cached_data = cached_data.decode("utf-8") - - # Skip null values - if cached_data == "null": - continue - - # Deserialize and check tenant_id - data = json.loads(cached_data) - if data.get("tenant_id") == tenant_id: - redis_client.delete(key) - deleted_count += 1 - except (json.JSONDecodeError, Exception) as e: - logger.warning("Failed to check cache key %s: %s", key, e) - continue - - if cursor == 0: - break - - logger.info( - "Invalidated %d token cache entries for tenant: %s (checked %d keys via scan)", - deleted_count, - tenant_id, - checked_count, - ) - return True - except Exception as e: - logger.warning("Failed to invalidate tenant token cache: %s", e) - return False diff --git a/api/tests/integration_tests/libs/test_api_token_cache_integration.py b/api/tests/integration_tests/libs/test_api_token_cache_integration.py index 8026ca6902..1e21a69ba4 100644 --- a/api/tests/integration_tests/libs/test_api_token_cache_integration.py +++ b/api/tests/integration_tests/libs/test_api_token_cache_integration.py @@ -167,43 +167,6 @@ class TestApiTokenCacheRedisIntegration: cache_keys = [m.decode('utf-8') if isinstance(m, bytes) else m for m in members] assert self.cache_key in cache_keys - def test_invalidate_by_tenant_via_index(self): - """Test tenant-wide cache invalidation using index (fast path).""" - from unittest.mock import MagicMock - - tenant_id = "test-tenant-id" - - # Create multiple tokens for the same tenant - for i in range(3): - token_value = f"test-token-{i}" - mock_token = MagicMock() - mock_token.id = f"test-id-{i}" - mock_token.app_id = "test-app" - mock_token.tenant_id = tenant_id - mock_token.type = "app" - mock_token.token = token_value - mock_token.last_used_at = None - mock_token.created_at = datetime.now() - - ApiTokenCache.set(token_value, "app", mock_token) - - # Verify all cached - for i in range(3): - key = f"api_token:app:test-token-{i}" - assert redis_client.exists(key) == 1 - - # Invalidate by tenant - result = ApiTokenCache.invalidate_by_tenant(tenant_id) - assert result is True - - # Verify all deleted - for i in range(3): - key = f"api_token:app:test-token-{i}" - assert redis_client.exists(key) == 0 - - # Verify index also deleted - assert redis_client.exists(f"tenant_tokens:{tenant_id}") == 0 - def test_concurrent_cache_access(self): """Test concurrent cache access doesn't cause issues.""" import concurrent.futures