From cb2b3e07baf16cd0d0b88bec249fdf50ef61b8bc Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Fri, 6 Feb 2026 10:08:46 +0800 Subject: [PATCH] fix: Standardized adjustment --- api/controllers/service_api/wraps.py | 78 +++++++++++++--------------- api/libs/api_token_cache.py | 52 +++++++++---------- 2 files changed, 61 insertions(+), 69 deletions(-) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 95f7d71331..12e54abb12 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -327,6 +327,37 @@ def validate_and_get_api_token(scope: str | None = None): # Cache miss - use Redis lock for single-flight mode # This ensures only one request queries DB for the same token concurrently + return _fetch_token_with_single_flight(auth_token, scope) + + +def _query_token_from_db(auth_token: str, scope: str | None) -> ApiToken: + """ + Query API token from database, update last_used_at, and cache the result. + + Raises Unauthorized if token is invalid. + """ + with Session(db.engine, expire_on_commit=False) as session: + current_time = naive_utc_now() + update_token_last_used_at(auth_token, scope, current_time, session=session) + + stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) + api_token = session.scalar(stmt) + + if not api_token: + ApiTokenCache.set(auth_token, scope, None) + raise Unauthorized("Access token is invalid") + + ApiTokenCache.set(auth_token, scope, api_token) + return api_token + + +def _fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken: + """ + Fetch token from DB with single-flight pattern using Redis lock. + + Ensures only one concurrent request queries the database for the same token. + Falls back to direct query if lock acquisition fails. + """ logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope) lock_key = f"api_token_query_lock:{scope}:{auth_token}" @@ -342,54 +373,17 @@ def validate_and_get_api_token(scope: str | None = None): logger.debug("Token cached by concurrent request, using cached version") return cached_token - # Still not cached - query database - with Session(db.engine, expire_on_commit=False) as session: - current_time = naive_utc_now() - update_token_last_used_at(auth_token, scope, current_time, session=session) - - stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) - api_token = session.scalar(stmt) - - if not api_token: - ApiTokenCache.set(auth_token, scope, None) - raise Unauthorized("Access token is invalid") - - ApiTokenCache.set(auth_token, scope, api_token) - return api_token + return _query_token_from_db(auth_token, scope) finally: lock.release() else: - # Lock acquisition timeout - fallback to direct query logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10]) - with Session(db.engine, expire_on_commit=False) as session: - current_time = naive_utc_now() - update_token_last_used_at(auth_token, scope, current_time, session=session) - - stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) - api_token = session.scalar(stmt) - - if not api_token: - ApiTokenCache.set(auth_token, scope, None) - raise Unauthorized("Access token is invalid") - - ApiTokenCache.set(auth_token, scope, api_token) - return api_token + return _query_token_from_db(auth_token, scope) + except Unauthorized: + raise except Exception as e: - # Redis lock failure - fallback to direct query to ensure service availability logger.warning("Redis lock failed for token query: %s, proceeding anyway", e) - with Session(db.engine, expire_on_commit=False) as session: - current_time = naive_utc_now() - update_token_last_used_at(auth_token, scope, current_time, session=session) - - stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) - api_token = session.scalar(stmt) - - if not api_token: - ApiTokenCache.set(auth_token, scope, None) - raise Unauthorized("Access token is invalid") - - ApiTokenCache.set(auth_token, scope, api_token) - return api_token + return _query_token_from_db(auth_token, scope) def _async_update_token_last_used_at(auth_token: str, scope: str | None): diff --git a/api/libs/api_token_cache.py b/api/libs/api_token_cache.py index 6c717d05d9..1326a62407 100644 --- a/api/libs/api_token_cache.py +++ b/api/libs/api_token_cache.py @@ -8,7 +8,6 @@ import logging from datetime import datetime from typing import Any -import orjson from pydantic import BaseModel from extensions.ext_redis import redis_client, redis_fallback @@ -71,7 +70,7 @@ class ApiTokenCache: @staticmethod def _serialize_token(api_token: Any) -> bytes: """ - Serialize ApiToken object to JSON bytes using orjson for better performance. + Serialize ApiToken object to JSON bytes. Args: api_token: ApiToken model instance or CachedApiToken @@ -79,27 +78,26 @@ class ApiTokenCache: Returns: JSON bytes representation """ - # If it's already a Pydantic model, use model_dump + # If it's already a Pydantic model, use model_dump_json directly if isinstance(api_token, CachedApiToken): - # Pydantic model -> dict -> orjson - return orjson.dumps(api_token.model_dump(mode="json")) - - # Otherwise, convert from SQLAlchemy model - data = { - "id": str(api_token.id), - "app_id": str(api_token.app_id) if api_token.app_id else None, - "tenant_id": str(api_token.tenant_id) if api_token.tenant_id else None, - "type": api_token.type, - "token": api_token.token, - "last_used_at": api_token.last_used_at.isoformat() if api_token.last_used_at else None, - "created_at": api_token.created_at.isoformat() if api_token.created_at else None, - } - return orjson.dumps(data) + return api_token.model_dump_json().encode("utf-8") + + # Otherwise, convert from SQLAlchemy model to CachedApiToken first + cached = CachedApiToken( + id=str(api_token.id), + app_id=str(api_token.app_id) if api_token.app_id else None, + tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None, + type=api_token.type, + token=api_token.token, + last_used_at=api_token.last_used_at, + created_at=api_token.created_at, + ) + return cached.model_dump_json().encode("utf-8") @staticmethod def _deserialize_token(cached_data: bytes | str) -> Any: """ - Deserialize JSON bytes/string back to a CachedApiToken Pydantic model using orjson. + Deserialize JSON bytes/string back to a CachedApiToken Pydantic model. Args: cached_data: JSON bytes or string from cache @@ -112,12 +110,11 @@ class ApiTokenCache: return None try: - # orjson.loads accepts bytes or str - data = orjson.loads(cached_data) - # Use Pydantic's model_validate for automatic validation - token_obj = CachedApiToken.model_validate(data) - return token_obj - except (ValueError, orjson.JSONDecodeError) as e: + # Pydantic's model_validate_json handles both bytes and str + if isinstance(cached_data, bytes): + cached_data = cached_data.decode("utf-8") + return CachedApiToken.model_validate_json(cached_data) + except (ValueError, Exception) as e: logger.warning("Failed to deserialize token from cache: %s", e) return None @@ -141,7 +138,7 @@ class ApiTokenCache: logger.debug("Cache miss for token key: %s", cache_key) return None - # orjson.loads handles both bytes and str automatically + # Pydantic handles deserialization logger.debug("Cache hit for token key: %s", cache_key) return ApiTokenCache._deserialize_token(cached_data) @@ -259,8 +256,9 @@ class ApiTokenCache: try: cached_data = redis_client.get(cache_key) if cached_data and cached_data != b"null": - data = orjson.loads(cached_data) - tenant_id = data.get("tenant_id") + cached_token = ApiTokenCache._deserialize_token(cached_data) + if cached_token: + tenant_id = cached_token.tenant_id except Exception as e: # If we can't get tenant_id, just delete the key without index cleanup logger.debug("Failed to get tenant_id for cache cleanup: %s", e)