fix: deduplicate API token DB query and standardize cache serialization on Pydantic

Yansong Zhang
2026-02-06 10:08:46 +08:00
parent 8cbd1af0d1
commit cb2b3e07ba
2 changed files with 61 additions and 69 deletions

View File

@@ -327,6 +327,37 @@ def validate_and_get_api_token(scope: str | None = None):
     # Cache miss - use Redis lock for single-flight mode
     # This ensures only one request queries DB for the same token concurrently
     return _fetch_token_with_single_flight(auth_token, scope)
+
+
+def _query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
+    """
+    Query API token from database, update last_used_at, and cache the result.
+
+    Raises Unauthorized if token is invalid.
+    """
+    with Session(db.engine, expire_on_commit=False) as session:
+        current_time = naive_utc_now()
+        update_token_last_used_at(auth_token, scope, current_time, session=session)
+        stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
+        api_token = session.scalar(stmt)
+
+        if not api_token:
+            ApiTokenCache.set(auth_token, scope, None)
+            raise Unauthorized("Access token is invalid")
+
+        ApiTokenCache.set(auth_token, scope, api_token)
+        return api_token
+
+
+def _fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken:
+    """
+    Fetch token from DB with single-flight pattern using Redis lock.
+
+    Ensures only one concurrent request queries the database for the same token.
+    Falls back to direct query if lock acquisition fails.
+    """
+    logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)
+    lock_key = f"api_token_query_lock:{scope}:{auth_token}"
@@ -342,54 +373,17 @@
                     logger.debug("Token cached by concurrent request, using cached version")
                     return cached_token
 
                 # Still not cached - query database
-                with Session(db.engine, expire_on_commit=False) as session:
-                    current_time = naive_utc_now()
-                    update_token_last_used_at(auth_token, scope, current_time, session=session)
-                    stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
-                    api_token = session.scalar(stmt)
-
-                    if not api_token:
-                        ApiTokenCache.set(auth_token, scope, None)
-                        raise Unauthorized("Access token is invalid")
-
-                    ApiTokenCache.set(auth_token, scope, api_token)
-                    return api_token
+                return _query_token_from_db(auth_token, scope)
             finally:
                 lock.release()
         else:
             # Lock acquisition timeout - fallback to direct query
             logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
-            with Session(db.engine, expire_on_commit=False) as session:
-                current_time = naive_utc_now()
-                update_token_last_used_at(auth_token, scope, current_time, session=session)
-                stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
-                api_token = session.scalar(stmt)
-
-                if not api_token:
-                    ApiTokenCache.set(auth_token, scope, None)
-                    raise Unauthorized("Access token is invalid")
-
-                ApiTokenCache.set(auth_token, scope, api_token)
-                return api_token
+            return _query_token_from_db(auth_token, scope)
     except Unauthorized:
        raise
    except Exception as e:
        # Redis lock failure - fallback to direct query to ensure service availability
        logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
-        with Session(db.engine, expire_on_commit=False) as session:
-            current_time = naive_utc_now()
-            update_token_last_used_at(auth_token, scope, current_time, session=session)
-            stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
-            api_token = session.scalar(stmt)
-
-            if not api_token:
-                ApiTokenCache.set(auth_token, scope, None)
-                raise Unauthorized("Access token is invalid")
-
-            ApiTokenCache.set(auth_token, scope, api_token)
-            return api_token
+        return _query_token_from_db(auth_token, scope)
 
 
 def _async_update_token_last_used_at(auth_token: str, scope: str | None):
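
For reference, the single-flight shape that _fetch_token_with_single_flight centralizes can be sketched in isolation. This is a minimal sketch assuming a redis-py client; fetch_from_db, cache_get, cache_set, and the lock key format are illustrative stand-ins, not the project's actual helpers:

import redis

redis_client = redis.Redis()  # assumed connection, for illustration only

def get_with_single_flight(key, fetch_from_db, cache_get, cache_set):
    # Fast path: a concurrent request may already have populated the cache.
    cached = cache_get(key)
    if cached is not None:
        return cached

    lock = redis_client.lock(f"query_lock:{key}", timeout=10)
    try:
        if lock.acquire(blocking=True, blocking_timeout=5):
            try:
                # Double-check after acquiring: the previous lock holder
                # may have filled the cache while we waited.
                cached = cache_get(key)
                if cached is not None:
                    return cached
                value = fetch_from_db(key)
                cache_set(key, value)
                return value
            finally:
                lock.release()
        # Lock wait timed out: degrade to a direct query instead of failing.
        return fetch_from_db(key)
    except redis.RedisError:
        # Redis unavailable: availability wins over deduplication.
        return fetch_from_db(key)

The double-check after acquiring the lock is what collapses N concurrent cache misses into a single database query; both fallback branches deliberately trade that guarantee away to keep serving requests.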

View File

@@ -8,7 +8,6 @@ import logging
 from datetime import datetime
 from typing import Any
 
-import orjson
 from pydantic import BaseModel
 
 from extensions.ext_redis import redis_client, redis_fallback
@@ -71,7 +70,7 @@ class ApiTokenCache:
     @staticmethod
     def _serialize_token(api_token: Any) -> bytes:
         """
-        Serialize ApiToken object to JSON bytes using orjson for better performance.
+        Serialize ApiToken object to JSON bytes.
 
         Args:
             api_token: ApiToken model instance or CachedApiToken
@@ -79,27 +78,26 @@
         Returns:
             JSON bytes representation
         """
-        # If it's already a Pydantic model, use model_dump
+        # If it's already a Pydantic model, use model_dump_json directly
         if isinstance(api_token, CachedApiToken):
-            # Pydantic model -> dict -> orjson
-            return orjson.dumps(api_token.model_dump(mode="json"))
-
-        # Otherwise, convert from SQLAlchemy model
-        data = {
-            "id": str(api_token.id),
-            "app_id": str(api_token.app_id) if api_token.app_id else None,
-            "tenant_id": str(api_token.tenant_id) if api_token.tenant_id else None,
-            "type": api_token.type,
-            "token": api_token.token,
-            "last_used_at": api_token.last_used_at.isoformat() if api_token.last_used_at else None,
-            "created_at": api_token.created_at.isoformat() if api_token.created_at else None,
-        }
-        return orjson.dumps(data)
+            return api_token.model_dump_json().encode("utf-8")
+
+        # Otherwise, convert from SQLAlchemy model to CachedApiToken first
+        cached = CachedApiToken(
+            id=str(api_token.id),
+            app_id=str(api_token.app_id) if api_token.app_id else None,
+            tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
+            type=api_token.type,
+            token=api_token.token,
+            last_used_at=api_token.last_used_at,
+            created_at=api_token.created_at,
+        )
+        return cached.model_dump_json().encode("utf-8")
 
     @staticmethod
     def _deserialize_token(cached_data: bytes | str) -> Any:
         """
-        Deserialize JSON bytes/string back to a CachedApiToken Pydantic model using orjson.
+        Deserialize JSON bytes/string back to a CachedApiToken Pydantic model.
 
         Args:
             cached_data: JSON bytes or string from cache
@@ -112,12 +110,11 @@ class ApiTokenCache:
             return None
 
         try:
-            # orjson.loads accepts bytes or str
-            data = orjson.loads(cached_data)
-            # Use Pydantic's model_validate for automatic validation
-            token_obj = CachedApiToken.model_validate(data)
-            return token_obj
-        except (ValueError, orjson.JSONDecodeError) as e:
+            # Pydantic's model_validate_json handles both bytes and str
+            if isinstance(cached_data, bytes):
+                cached_data = cached_data.decode("utf-8")
+            return CachedApiToken.model_validate_json(cached_data)
+        except Exception as e:
             logger.warning("Failed to deserialize token from cache: %s", e)
             return None
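
With these changes, Pydantic v2 owns the whole cache round trip: model_dump_json on the way in, model_validate_json on the way out. A self-contained sketch of the same pattern, using a trimmed-down stand-in model (the field set here is illustrative, not the project's full CachedApiToken):

from datetime import datetime

from pydantic import BaseModel

class CachedToken(BaseModel):
    id: str
    tenant_id: str | None = None
    token: str
    last_used_at: datetime | None = None

def serialize(tok: CachedToken) -> bytes:
    # model_dump_json renders datetimes as ISO 8601 strings.
    return tok.model_dump_json().encode("utf-8")

def deserialize(raw: bytes | str) -> CachedToken | None:
    try:
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        # model_validate_json parses and validates in one step.
        return CachedToken.model_validate_json(raw)
    except Exception:
        return None

tok = CachedToken(id="1", token="app-key", last_used_at=datetime(2026, 2, 6))
assert deserialize(serialize(tok)) == tok

Dropping orjson removes a dependency from this path without an obvious performance cost, since Pydantic v2's JSON parsing is itself Rust-backed (pydantic-core), and validation now happens in the same call as parsing.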
@@ -141,7 +138,7 @@ class ApiTokenCache:
             logger.debug("Cache miss for token key: %s", cache_key)
             return None
 
-        # orjson.loads handles both bytes and str automatically
+        # Pydantic handles deserialization
         logger.debug("Cache hit for token key: %s", cache_key)
         return ApiTokenCache._deserialize_token(cached_data)
@@ -259,8 +256,9 @@ class ApiTokenCache:
         try:
             cached_data = redis_client.get(cache_key)
             if cached_data and cached_data != b"null":
-                data = orjson.loads(cached_data)
-                tenant_id = data.get("tenant_id")
+                cached_token = ApiTokenCache._deserialize_token(cached_data)
+                if cached_token:
+                    tenant_id = cached_token.tenant_id
         except Exception as e:
             # If we can't get tenant_id, just delete the key without index cleanup
             logger.debug("Failed to get tenant_id for cache cleanup: %s", e)
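
This last hunk reads the cached value back through the shared deserializer just to recover tenant_id before deleting the key. A sketch of why that matters, assuming a per-tenant index set that maps a tenant to its cached token keys; the key and index names here are guesses for illustration, not the project's actual layout:

import redis

r = redis.Redis()  # assumed connection, for illustration only

def delete_cached_token(cache_key: str, deserialize) -> None:
    tenant_id = None
    try:
        raw = r.get(cache_key)
        if raw and raw != b"null":
            cached = deserialize(raw)  # shared helper, as in the diff
            if cached:
                tenant_id = cached.tenant_id
    except Exception:
        # Can't recover tenant_id: still delete the key, skip index cleanup.
        pass

    r.delete(cache_key)
    if tenant_id:
        # Drop the key from the (assumed) per-tenant index set so that
        # tenant-wide invalidation doesn't chase stale members.
        r.srem(f"tenant_token_keys:{tenant_id}", cache_key)

Reusing _deserialize_token here means any future change to the cached schema is handled in exactly one place; note the b"null" negative-cache sentinel is still checked by the caller, as the diff shows.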