fix: Standardized adjustment

Yansong Zhang
2026-02-06 10:08:46 +08:00
parent 8cbd1af0d1
commit cb2b3e07ba
2 changed files with 61 additions and 69 deletions


@@ -327,6 +327,37 @@ def validate_and_get_api_token(scope: str | None = None):
 
     # Cache miss - use Redis lock for single-flight mode
     # This ensures only one request queries DB for the same token concurrently
+    return _fetch_token_with_single_flight(auth_token, scope)
+
+
+def _query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
+    """
+    Query API token from database, update last_used_at, and cache the result.
+
+    Raises Unauthorized if token is invalid.
+    """
+    with Session(db.engine, expire_on_commit=False) as session:
+        current_time = naive_utc_now()
+        update_token_last_used_at(auth_token, scope, current_time, session=session)
+
+        stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
+        api_token = session.scalar(stmt)
+
+        if not api_token:
+            ApiTokenCache.set(auth_token, scope, None)
+            raise Unauthorized("Access token is invalid")
+
+        ApiTokenCache.set(auth_token, scope, api_token)
+        return api_token
+
+
+def _fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken:
+    """
+    Fetch token from DB with single-flight pattern using Redis lock.
+
+    Ensures only one concurrent request queries the database for the same token.
+    Falls back to direct query if lock acquisition fails.
+    """
     logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)
     lock_key = f"api_token_query_lock:{scope}:{auth_token}"
@@ -342,54 +373,17 @@ def validate_and_get_api_token(scope: str | None = None):
                     logger.debug("Token cached by concurrent request, using cached version")
                     return cached_token
 
-                # Still not cached - query database
-                with Session(db.engine, expire_on_commit=False) as session:
-                    current_time = naive_utc_now()
-                    update_token_last_used_at(auth_token, scope, current_time, session=session)
-
-                    stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
-                    api_token = session.scalar(stmt)
-
-                    if not api_token:
-                        ApiTokenCache.set(auth_token, scope, None)
-                        raise Unauthorized("Access token is invalid")
-
-                    ApiTokenCache.set(auth_token, scope, api_token)
-                    return api_token
+                return _query_token_from_db(auth_token, scope)
             finally:
                 lock.release()
         else:
-            # Lock acquisition timeout - fallback to direct query
             logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
-            with Session(db.engine, expire_on_commit=False) as session:
-                current_time = naive_utc_now()
-                update_token_last_used_at(auth_token, scope, current_time, session=session)
-
-                stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
-                api_token = session.scalar(stmt)
-
-                if not api_token:
-                    ApiTokenCache.set(auth_token, scope, None)
-                    raise Unauthorized("Access token is invalid")
-
-                ApiTokenCache.set(auth_token, scope, api_token)
-                return api_token
+            return _query_token_from_db(auth_token, scope)
+    except Unauthorized:
+        raise
     except Exception as e:
-        # Redis lock failure - fallback to direct query to ensure service availability
         logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
-        with Session(db.engine, expire_on_commit=False) as session:
-            current_time = naive_utc_now()
-            update_token_last_used_at(auth_token, scope, current_time, session=session)
-
-            stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
-            api_token = session.scalar(stmt)
-
-            if not api_token:
-                ApiTokenCache.set(auth_token, scope, None)
-                raise Unauthorized("Access token is invalid")
-
-            ApiTokenCache.set(auth_token, scope, api_token)
-            return api_token
 
 
 def _async_update_token_last_used_at(auth_token: str, scope: str | None):
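
Besides collapsing the three duplicated query blocks into _query_token_from_db, this hunk adds except Unauthorized: raise ahead of the broad fallback handler. Without it, an invalid-token error raised inside the locked section would be caught by except Exception and trigger a pointless second query for a token already known to be invalid. A condensed, hypothetical illustration of why the ordering matters (names are stand-ins, not the commit's code):

from werkzeug.exceptions import Unauthorized

def fetch(token_is_valid: bool) -> str:
    try:
        if not token_is_valid:
            raise Unauthorized("Access token is invalid")
        return "token-record"
    except Unauthorized:
        raise              # auth failures propagate to the caller unchanged
    except Exception:
        return "fallback"  # only infrastructure errors reach this path

print(fetch(True))   # -> token-record
# fetch(False) raises Unauthorized instead of returning "fallback"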


@@ -8,7 +8,6 @@ import logging
 from datetime import datetime
 from typing import Any
 
-import orjson
 from pydantic import BaseModel
 
 from extensions.ext_redis import redis_client, redis_fallback
@@ -71,7 +70,7 @@ class ApiTokenCache:
     @staticmethod
     def _serialize_token(api_token: Any) -> bytes:
         """
-        Serialize ApiToken object to JSON bytes using orjson for better performance.
+        Serialize ApiToken object to JSON bytes.
 
         Args:
             api_token: ApiToken model instance or CachedApiToken
@@ -79,27 +78,26 @@ class ApiTokenCache:
         Returns:
             JSON bytes representation
         """
-        # If it's already a Pydantic model, use model_dump
+        # If it's already a Pydantic model, use model_dump_json directly
         if isinstance(api_token, CachedApiToken):
-            # Pydantic model -> dict -> orjson
-            return orjson.dumps(api_token.model_dump(mode="json"))
+            return api_token.model_dump_json().encode("utf-8")
 
-        # Otherwise, convert from SQLAlchemy model
-        data = {
-            "id": str(api_token.id),
-            "app_id": str(api_token.app_id) if api_token.app_id else None,
-            "tenant_id": str(api_token.tenant_id) if api_token.tenant_id else None,
-            "type": api_token.type,
-            "token": api_token.token,
-            "last_used_at": api_token.last_used_at.isoformat() if api_token.last_used_at else None,
-            "created_at": api_token.created_at.isoformat() if api_token.created_at else None,
-        }
-        return orjson.dumps(data)
+        # Otherwise, convert from SQLAlchemy model to CachedApiToken first
+        cached = CachedApiToken(
+            id=str(api_token.id),
+            app_id=str(api_token.app_id) if api_token.app_id else None,
+            tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
+            type=api_token.type,
+            token=api_token.token,
+            last_used_at=api_token.last_used_at,
+            created_at=api_token.created_at,
+        )
+        return cached.model_dump_json().encode("utf-8")
 
     @staticmethod
     def _deserialize_token(cached_data: bytes | str) -> Any:
         """
-        Deserialize JSON bytes/string back to a CachedApiToken Pydantic model using orjson.
+        Deserialize JSON bytes/string back to a CachedApiToken Pydantic model.
 
         Args:
             cached_data: JSON bytes or string from cache
@@ -112,12 +110,11 @@ class ApiTokenCache:
             return None
 
         try:
-            # orjson.loads accepts bytes or str
-            data = orjson.loads(cached_data)
-            # Use Pydantic's model_validate for automatic validation
-            token_obj = CachedApiToken.model_validate(data)
-            return token_obj
-        except (ValueError, orjson.JSONDecodeError) as e:
+            # Pydantic's model_validate_json handles both bytes and str
+            if isinstance(cached_data, bytes):
+                cached_data = cached_data.decode("utf-8")
+            return CachedApiToken.model_validate_json(cached_data)
+        except (ValueError, Exception) as e:
             logger.warning("Failed to deserialize token from cache: %s", e)
             return None
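
Taken together, the serialize and deserialize changes replace the orjson round trip with Pydantic's built-in JSON codec. A minimal round-trip sketch under that assumption; this CachedApiToken is a trimmed stand-in for the real model:

from datetime import datetime
from pydantic import BaseModel

class CachedApiToken(BaseModel):
    id: str
    tenant_id: str | None = None
    last_used_at: datetime | None = None

token = CachedApiToken(id="t1", tenant_id="acme", last_used_at=datetime(2026, 2, 6))
raw: bytes = token.model_dump_json().encode("utf-8")  # datetimes serialize to ISO 8601
restored = CachedApiToken.model_validate_json(raw)    # Pydantic v2 accepts str or bytes
assert restored == token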
@@ -141,7 +138,7 @@ class ApiTokenCache:
             logger.debug("Cache miss for token key: %s", cache_key)
             return None
 
-        # orjson.loads handles both bytes and str automatically
+        # Pydantic handles deserialization
         logger.debug("Cache hit for token key: %s", cache_key)
         return ApiTokenCache._deserialize_token(cached_data)
@@ -259,8 +256,9 @@ class ApiTokenCache:
         try:
             cached_data = redis_client.get(cache_key)
             if cached_data and cached_data != b"null":
-                data = orjson.loads(cached_data)
-                tenant_id = data.get("tenant_id")
+                cached_token = ApiTokenCache._deserialize_token(cached_data)
+                if cached_token:
+                    tenant_id = cached_token.tenant_id
         except Exception as e:
             # If we can't get tenant_id, just delete the key without index cleanup
             logger.debug("Failed to get tenant_id for cache cleanup: %s", e)