mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
fix: Standardized adjustment
This commit is contained in:
@ -327,6 +327,37 @@ def validate_and_get_api_token(scope: str | None = None):
|
|||||||
|
|
||||||
# Cache miss - use Redis lock for single-flight mode
|
# Cache miss - use Redis lock for single-flight mode
|
||||||
# This ensures only one request queries DB for the same token concurrently
|
# This ensures only one request queries DB for the same token concurrently
|
||||||
|
return _fetch_token_with_single_flight(auth_token, scope)
|
||||||
|
|
||||||
|
|
||||||
|
def _query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
|
||||||
|
"""
|
||||||
|
Query API token from database, update last_used_at, and cache the result.
|
||||||
|
|
||||||
|
Raises Unauthorized if token is invalid.
|
||||||
|
"""
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
current_time = naive_utc_now()
|
||||||
|
update_token_last_used_at(auth_token, scope, current_time, session=session)
|
||||||
|
|
||||||
|
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||||
|
api_token = session.scalar(stmt)
|
||||||
|
|
||||||
|
if not api_token:
|
||||||
|
ApiTokenCache.set(auth_token, scope, None)
|
||||||
|
raise Unauthorized("Access token is invalid")
|
||||||
|
|
||||||
|
ApiTokenCache.set(auth_token, scope, api_token)
|
||||||
|
return api_token
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken:
|
||||||
|
"""
|
||||||
|
Fetch token from DB with single-flight pattern using Redis lock.
|
||||||
|
|
||||||
|
Ensures only one concurrent request queries the database for the same token.
|
||||||
|
Falls back to direct query if lock acquisition fails.
|
||||||
|
"""
|
||||||
logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)
|
logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)
|
||||||
|
|
||||||
lock_key = f"api_token_query_lock:{scope}:{auth_token}"
|
lock_key = f"api_token_query_lock:{scope}:{auth_token}"
|
||||||
@ -342,54 +373,17 @@ def validate_and_get_api_token(scope: str | None = None):
|
|||||||
logger.debug("Token cached by concurrent request, using cached version")
|
logger.debug("Token cached by concurrent request, using cached version")
|
||||||
return cached_token
|
return cached_token
|
||||||
|
|
||||||
# Still not cached - query database
|
return _query_token_from_db(auth_token, scope)
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
|
||||||
current_time = naive_utc_now()
|
|
||||||
update_token_last_used_at(auth_token, scope, current_time, session=session)
|
|
||||||
|
|
||||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
|
||||||
api_token = session.scalar(stmt)
|
|
||||||
|
|
||||||
if not api_token:
|
|
||||||
ApiTokenCache.set(auth_token, scope, None)
|
|
||||||
raise Unauthorized("Access token is invalid")
|
|
||||||
|
|
||||||
ApiTokenCache.set(auth_token, scope, api_token)
|
|
||||||
return api_token
|
|
||||||
finally:
|
finally:
|
||||||
lock.release()
|
lock.release()
|
||||||
else:
|
else:
|
||||||
# Lock acquisition timeout - fallback to direct query
|
|
||||||
logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
|
logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
return _query_token_from_db(auth_token, scope)
|
||||||
current_time = naive_utc_now()
|
except Unauthorized:
|
||||||
update_token_last_used_at(auth_token, scope, current_time, session=session)
|
raise
|
||||||
|
|
||||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
|
||||||
api_token = session.scalar(stmt)
|
|
||||||
|
|
||||||
if not api_token:
|
|
||||||
ApiTokenCache.set(auth_token, scope, None)
|
|
||||||
raise Unauthorized("Access token is invalid")
|
|
||||||
|
|
||||||
ApiTokenCache.set(auth_token, scope, api_token)
|
|
||||||
return api_token
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Redis lock failure - fallback to direct query to ensure service availability
|
|
||||||
logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
|
logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
return _query_token_from_db(auth_token, scope)
|
||||||
current_time = naive_utc_now()
|
|
||||||
update_token_last_used_at(auth_token, scope, current_time, session=session)
|
|
||||||
|
|
||||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
|
||||||
api_token = session.scalar(stmt)
|
|
||||||
|
|
||||||
if not api_token:
|
|
||||||
ApiTokenCache.set(auth_token, scope, None)
|
|
||||||
raise Unauthorized("Access token is invalid")
|
|
||||||
|
|
||||||
ApiTokenCache.set(auth_token, scope, api_token)
|
|
||||||
return api_token
|
|
||||||
|
|
||||||
|
|
||||||
def _async_update_token_last_used_at(auth_token: str, scope: str | None):
|
def _async_update_token_last_used_at(auth_token: str, scope: str | None):
|
||||||
|
|||||||
@ -8,7 +8,6 @@ import logging
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import orjson
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from extensions.ext_redis import redis_client, redis_fallback
|
from extensions.ext_redis import redis_client, redis_fallback
|
||||||
@ -71,7 +70,7 @@ class ApiTokenCache:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_token(api_token: Any) -> bytes:
|
def _serialize_token(api_token: Any) -> bytes:
|
||||||
"""
|
"""
|
||||||
Serialize ApiToken object to JSON bytes using orjson for better performance.
|
Serialize ApiToken object to JSON bytes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_token: ApiToken model instance or CachedApiToken
|
api_token: ApiToken model instance or CachedApiToken
|
||||||
@ -79,27 +78,26 @@ class ApiTokenCache:
|
|||||||
Returns:
|
Returns:
|
||||||
JSON bytes representation
|
JSON bytes representation
|
||||||
"""
|
"""
|
||||||
# If it's already a Pydantic model, use model_dump
|
# If it's already a Pydantic model, use model_dump_json directly
|
||||||
if isinstance(api_token, CachedApiToken):
|
if isinstance(api_token, CachedApiToken):
|
||||||
# Pydantic model -> dict -> orjson
|
return api_token.model_dump_json().encode("utf-8")
|
||||||
return orjson.dumps(api_token.model_dump(mode="json"))
|
|
||||||
|
# Otherwise, convert from SQLAlchemy model to CachedApiToken first
|
||||||
# Otherwise, convert from SQLAlchemy model
|
cached = CachedApiToken(
|
||||||
data = {
|
id=str(api_token.id),
|
||||||
"id": str(api_token.id),
|
app_id=str(api_token.app_id) if api_token.app_id else None,
|
||||||
"app_id": str(api_token.app_id) if api_token.app_id else None,
|
tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
|
||||||
"tenant_id": str(api_token.tenant_id) if api_token.tenant_id else None,
|
type=api_token.type,
|
||||||
"type": api_token.type,
|
token=api_token.token,
|
||||||
"token": api_token.token,
|
last_used_at=api_token.last_used_at,
|
||||||
"last_used_at": api_token.last_used_at.isoformat() if api_token.last_used_at else None,
|
created_at=api_token.created_at,
|
||||||
"created_at": api_token.created_at.isoformat() if api_token.created_at else None,
|
)
|
||||||
}
|
return cached.model_dump_json().encode("utf-8")
|
||||||
return orjson.dumps(data)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _deserialize_token(cached_data: bytes | str) -> Any:
|
def _deserialize_token(cached_data: bytes | str) -> Any:
|
||||||
"""
|
"""
|
||||||
Deserialize JSON bytes/string back to a CachedApiToken Pydantic model using orjson.
|
Deserialize JSON bytes/string back to a CachedApiToken Pydantic model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cached_data: JSON bytes or string from cache
|
cached_data: JSON bytes or string from cache
|
||||||
@ -112,12 +110,11 @@ class ApiTokenCache:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# orjson.loads accepts bytes or str
|
# Pydantic's model_validate_json handles both bytes and str
|
||||||
data = orjson.loads(cached_data)
|
if isinstance(cached_data, bytes):
|
||||||
# Use Pydantic's model_validate for automatic validation
|
cached_data = cached_data.decode("utf-8")
|
||||||
token_obj = CachedApiToken.model_validate(data)
|
return CachedApiToken.model_validate_json(cached_data)
|
||||||
return token_obj
|
except (ValueError, Exception) as e:
|
||||||
except (ValueError, orjson.JSONDecodeError) as e:
|
|
||||||
logger.warning("Failed to deserialize token from cache: %s", e)
|
logger.warning("Failed to deserialize token from cache: %s", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -141,7 +138,7 @@ class ApiTokenCache:
|
|||||||
logger.debug("Cache miss for token key: %s", cache_key)
|
logger.debug("Cache miss for token key: %s", cache_key)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# orjson.loads handles both bytes and str automatically
|
# Pydantic handles deserialization
|
||||||
logger.debug("Cache hit for token key: %s", cache_key)
|
logger.debug("Cache hit for token key: %s", cache_key)
|
||||||
return ApiTokenCache._deserialize_token(cached_data)
|
return ApiTokenCache._deserialize_token(cached_data)
|
||||||
|
|
||||||
@ -259,8 +256,9 @@ class ApiTokenCache:
|
|||||||
try:
|
try:
|
||||||
cached_data = redis_client.get(cache_key)
|
cached_data = redis_client.get(cache_key)
|
||||||
if cached_data and cached_data != b"null":
|
if cached_data and cached_data != b"null":
|
||||||
data = orjson.loads(cached_data)
|
cached_token = ApiTokenCache._deserialize_token(cached_data)
|
||||||
tenant_id = data.get("tenant_id")
|
if cached_token:
|
||||||
|
tenant_id = cached_token.tenant_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If we can't get tenant_id, just delete the key without index cleanup
|
# If we can't get tenant_id, just delete the key without index cleanup
|
||||||
logger.debug("Failed to get tenant_id for cache cleanup: %s", e)
|
logger.debug("Failed to get tenant_id for cache cleanup: %s", e)
|
||||||
|
|||||||
Reference in New Issue
Block a user