mirror of
https://github.com/langgenius/dify.git
synced 2026-02-27 21:17:13 +08:00
make it great again
This commit is contained in:
@ -325,27 +325,71 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
_async_update_token_last_used_at(auth_token, scope)
|
||||
return cached_token
|
||||
|
||||
# Cache miss - query database
|
||||
logger.debug("Token cache miss, querying database for scope: %s", scope)
|
||||
current_time = naive_utc_now()
|
||||
# Cache miss - use Redis lock for single-flight mode
|
||||
# This ensures only one request queries DB for the same token concurrently
|
||||
logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)
|
||||
|
||||
lock_key = f"api_token_query_lock:{scope}:{auth_token}"
|
||||
lock = redis_client.lock(lock_key, timeout=10, blocking_timeout=5)
|
||||
|
||||
try:
|
||||
if lock.acquire(blocking=True):
|
||||
try:
|
||||
# Double-check cache after acquiring lock
|
||||
# (another concurrent request might have already cached it)
|
||||
cached_token = ApiTokenCache.get(auth_token, scope)
|
||||
if cached_token is not None:
|
||||
logger.debug("Token cached by concurrent request, using cached version")
|
||||
return cached_token
|
||||
|
||||
# Still not cached - query database
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
current_time = naive_utc_now()
|
||||
update_token_last_used_at(auth_token, scope, current_time, session=session)
|
||||
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
api_token = session.scalar(stmt)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Use unified update method to avoid code duplication with Celery task
|
||||
update_token_last_used_at(auth_token, scope, current_time, session=session)
|
||||
|
||||
# Query the token
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
api_token = session.scalar(stmt)
|
||||
if not api_token:
|
||||
ApiTokenCache.set(auth_token, scope, None)
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
if not api_token:
|
||||
# Cache the null result to prevent cache penetration attacks
|
||||
ApiTokenCache.set(auth_token, scope, None)
|
||||
raise Unauthorized("Access token is invalid")
|
||||
ApiTokenCache.set(auth_token, scope, api_token)
|
||||
return api_token
|
||||
finally:
|
||||
lock.release()
|
||||
else:
|
||||
# Lock acquisition timeout - fallback to direct query
|
||||
logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
current_time = naive_utc_now()
|
||||
update_token_last_used_at(auth_token, scope, current_time, session=session)
|
||||
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
api_token = session.scalar(stmt)
|
||||
|
||||
# Cache the valid token
|
||||
ApiTokenCache.set(auth_token, scope, api_token)
|
||||
if not api_token:
|
||||
ApiTokenCache.set(auth_token, scope, None)
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
return api_token
|
||||
ApiTokenCache.set(auth_token, scope, api_token)
|
||||
return api_token
|
||||
except Exception as e:
|
||||
# Redis lock failure - fallback to direct query to ensure service availability
|
||||
logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
current_time = naive_utc_now()
|
||||
update_token_last_used_at(auth_token, scope, current_time, session=session)
|
||||
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
api_token = session.scalar(stmt)
|
||||
|
||||
if not api_token:
|
||||
ApiTokenCache.set(auth_token, scope, None)
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
ApiTokenCache.set(auth_token, scope, api_token)
|
||||
return api_token
|
||||
|
||||
|
||||
def _async_update_token_last_used_at(auth_token: str, scope: str | None):
|
||||
|
||||
@ -9,36 +9,33 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CachedApiToken:
|
||||
class CachedApiToken(BaseModel):
|
||||
"""
|
||||
Simple data class to represent a cached API token.
|
||||
Pydantic model for cached API token data.
|
||||
|
||||
This is NOT a SQLAlchemy model instance, but a plain Python object
|
||||
This is NOT a SQLAlchemy model instance, but a plain Pydantic model
|
||||
that mimics the ApiToken model interface for read-only access.
|
||||
|
||||
Using Pydantic provides:
|
||||
- Automatic type validation
|
||||
- Better IDE support
|
||||
- Built-in serialization/deserialization
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
app_id: str | None,
|
||||
tenant_id: str | None,
|
||||
type: str,
|
||||
token: str,
|
||||
last_used_at: datetime | None,
|
||||
created_at: datetime | None,
|
||||
):
|
||||
self.id = id
|
||||
self.app_id = app_id
|
||||
self.tenant_id = tenant_id
|
||||
self.type = type
|
||||
self.token = token
|
||||
self.last_used_at = last_used_at
|
||||
self.created_at = created_at
|
||||
id: str
|
||||
app_id: str | None
|
||||
tenant_id: str | None
|
||||
type: str
|
||||
token: str
|
||||
last_used_at: datetime | None
|
||||
created_at: datetime | None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<CachedApiToken id={self.id} type={self.type}>"
|
||||
@ -77,11 +74,16 @@ class ApiTokenCache:
|
||||
Serialize ApiToken object to JSON string.
|
||||
|
||||
Args:
|
||||
api_token: ApiToken model instance
|
||||
api_token: ApiToken model instance or CachedApiToken
|
||||
|
||||
Returns:
|
||||
JSON string representation
|
||||
"""
|
||||
# If it's already a Pydantic model, use model_dump_json
|
||||
if isinstance(api_token, CachedApiToken):
|
||||
return api_token.model_dump_json()
|
||||
|
||||
# Otherwise, convert from SQLAlchemy model
|
||||
data = {
|
||||
"id": str(api_token.id),
|
||||
"app_id": str(api_token.app_id) if api_token.app_id else None,
|
||||
@ -96,7 +98,7 @@ class ApiTokenCache:
|
||||
@staticmethod
|
||||
def _deserialize_token(cached_data: str) -> Any:
|
||||
"""
|
||||
Deserialize JSON string back to a CachedApiToken object.
|
||||
Deserialize JSON string back to a CachedApiToken Pydantic model.
|
||||
|
||||
Args:
|
||||
cached_data: JSON string from cache
|
||||
@ -109,22 +111,10 @@ class ApiTokenCache:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.loads(cached_data)
|
||||
|
||||
# Create a simple data object (NOT a SQLAlchemy model instance)
|
||||
# This is safe because it's just a plain Python object with attributes
|
||||
token_obj = CachedApiToken(
|
||||
id=data["id"],
|
||||
app_id=data["app_id"],
|
||||
tenant_id=data["tenant_id"],
|
||||
type=data["type"],
|
||||
token=data["token"],
|
||||
last_used_at=datetime.fromisoformat(data["last_used_at"]) if data["last_used_at"] else None,
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data["created_at"] else None,
|
||||
)
|
||||
|
||||
# Use Pydantic's model_validate_json for automatic validation
|
||||
token_obj = CachedApiToken.model_validate_json(cached_data)
|
||||
return token_obj
|
||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning("Failed to deserialize token from cache: %s", e)
|
||||
return None
|
||||
|
||||
@ -289,92 +279,3 @@ class ApiTokenCache:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to delete token cache: %s", e)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=False)
|
||||
def invalidate_by_tenant(tenant_id: str) -> bool:
|
||||
"""
|
||||
Invalidate all API token caches for a specific tenant.
|
||||
Use this when tenant status changes or tokens are batch updated.
|
||||
|
||||
Uses a two-tier approach:
|
||||
1. Try to use tenant index (fast, O(n) where n = tenant's tokens)
|
||||
2. Fallback to full scan if index doesn't exist (slow, O(N) where N = all tokens)
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Try using tenant index first (efficient approach)
|
||||
index_key = f"tenant_tokens:{tenant_id}"
|
||||
cache_keys = redis_client.smembers(index_key)
|
||||
|
||||
if cache_keys:
|
||||
# Index exists - use it (fast path)
|
||||
deleted_count = 0
|
||||
for cache_key in cache_keys:
|
||||
if isinstance(cache_key, bytes):
|
||||
cache_key = cache_key.decode("utf-8")
|
||||
redis_client.delete(cache_key)
|
||||
deleted_count += 1
|
||||
|
||||
# Delete the index itself
|
||||
redis_client.delete(index_key)
|
||||
|
||||
logger.info(
|
||||
"Invalidated %d token cache entries for tenant: %s (via index)",
|
||||
deleted_count,
|
||||
tenant_id,
|
||||
)
|
||||
return True
|
||||
|
||||
# Index doesn't exist - fallback to scanning (slow path)
|
||||
logger.info("Tenant index not found, falling back to full scan for tenant: %s", tenant_id)
|
||||
|
||||
pattern = f"{CACHE_KEY_PREFIX}:*"
|
||||
cursor = 0
|
||||
deleted_count = 0
|
||||
checked_count = 0
|
||||
|
||||
while True:
|
||||
cursor, keys = redis_client.scan(cursor, match=pattern, count=100)
|
||||
if keys:
|
||||
for key in keys:
|
||||
checked_count += 1
|
||||
try:
|
||||
# Fetch and check if this token belongs to the tenant
|
||||
cached_data = redis_client.get(key)
|
||||
if cached_data:
|
||||
# Decode if bytes
|
||||
if isinstance(cached_data, bytes):
|
||||
cached_data = cached_data.decode("utf-8")
|
||||
|
||||
# Skip null values
|
||||
if cached_data == "null":
|
||||
continue
|
||||
|
||||
# Deserialize and check tenant_id
|
||||
data = json.loads(cached_data)
|
||||
if data.get("tenant_id") == tenant_id:
|
||||
redis_client.delete(key)
|
||||
deleted_count += 1
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.warning("Failed to check cache key %s: %s", key, e)
|
||||
continue
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
"Invalidated %d token cache entries for tenant: %s (checked %d keys via scan)",
|
||||
deleted_count,
|
||||
tenant_id,
|
||||
checked_count,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning("Failed to invalidate tenant token cache: %s", e)
|
||||
return False
|
||||
|
||||
@ -167,43 +167,6 @@ class TestApiTokenCacheRedisIntegration:
|
||||
cache_keys = [m.decode('utf-8') if isinstance(m, bytes) else m for m in members]
|
||||
assert self.cache_key in cache_keys
|
||||
|
||||
def test_invalidate_by_tenant_via_index(self):
|
||||
"""Test tenant-wide cache invalidation using index (fast path)."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
tenant_id = "test-tenant-id"
|
||||
|
||||
# Create multiple tokens for the same tenant
|
||||
for i in range(3):
|
||||
token_value = f"test-token-{i}"
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = f"test-id-{i}"
|
||||
mock_token.app_id = "test-app"
|
||||
mock_token.tenant_id = tenant_id
|
||||
mock_token.type = "app"
|
||||
mock_token.token = token_value
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
ApiTokenCache.set(token_value, "app", mock_token)
|
||||
|
||||
# Verify all cached
|
||||
for i in range(3):
|
||||
key = f"api_token:app:test-token-{i}"
|
||||
assert redis_client.exists(key) == 1
|
||||
|
||||
# Invalidate by tenant
|
||||
result = ApiTokenCache.invalidate_by_tenant(tenant_id)
|
||||
assert result is True
|
||||
|
||||
# Verify all deleted
|
||||
for i in range(3):
|
||||
key = f"api_token:app:test-token-{i}"
|
||||
assert redis_client.exists(key) == 0
|
||||
|
||||
# Verify index also deleted
|
||||
assert redis_client.exists(f"tenant_tokens:{tenant_id}") == 0
|
||||
|
||||
def test_concurrent_cache_access(self):
|
||||
"""Test concurrent cache access doesn't cause issues."""
|
||||
import concurrent.futures
|
||||
|
||||
Reference in New Issue
Block a user