Add Redis caching for API token validation

This commit is contained in:
Yansong Zhang
2026-02-03 15:03:11 +08:00
parent 840a975fef
commit d58d3f5bde
8 changed files with 630 additions and 1 deletions

View File

@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.api_token_cache import ApiTokenCache
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
@ -131,6 +132,9 @@ class BaseApiKeyResource(Resource):
if key is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
# Invalidate cache before deleting from database
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@ -51,6 +51,7 @@ from fields.dataset_fields import (
weighted_score_fields,
)
from fields.document_fields import document_status_fields
from libs.api_token_cache import ApiTokenCache
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
@ -820,6 +821,9 @@ class DatasetApiDeleteApi(Resource):
if key is None:
console_ns.abort(404, message="API key not found")
# Invalidate cache before deleting from database
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@ -17,6 +17,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.api_token_cache import ApiTokenCache
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
@ -296,7 +297,14 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
Validate and get API token with Redis caching.
This function uses a two-tier approach:
1. First checks Redis cache for the token
2. If not cached, queries database and caches the result
The last_used_at field is updated asynchronously via Celery task
to avoid blocking the request.
"""
auth_header = request.headers.get("Authorization")
if auth_header is None or " " not in auth_header:
@ -308,8 +316,20 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
# Try to get token from cache first
# Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
cached_token = ApiTokenCache.get(auth_token, scope)
if cached_token is not None:
logger.debug("Token validation served from cache for scope: %s", scope)
# Asynchronously update last_used_at (non-blocking)
_async_update_token_last_used_at(auth_token, scope)
return cached_token
# Cache miss - query database
logger.debug("Token cache miss, querying database for scope: %s", scope)
current_time = naive_utc_now()
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
@ -329,10 +349,35 @@ def validate_and_get_api_token(scope: str | None = None):
if not api_token:
raise Unauthorized("Access token is invalid")
# Cache the valid token
ApiTokenCache.set(auth_token, scope, api_token)
return api_token
def _async_update_token_last_used_at(auth_token: str, scope: str | None):
    """
    Asynchronously update the last_used_at timestamp for a token.

    This schedules a Celery task to update the database without blocking
    the current request. The start time is passed to ensure only older
    records are updated, providing natural concurrency control.

    Args:
        auth_token: The raw bearer token extracted from the request.
        scope: The token type/scope (e.g., 'app', 'dataset'), or None.
    """
    try:
        # Local import defers the tasks module until first use — presumably to
        # avoid a circular import with this wraps module; verify against callers.
        from tasks.update_api_token_last_used_task import update_api_token_last_used_task

        # Record the request start time for concurrency control
        # NOTE(review): this is actually the *scheduling* time, not the true
        # request start; close enough at last_used_at granularity, but confirm.
        start_time = naive_utc_now()
        start_time_iso = start_time.isoformat()

        # Fire and forget - don't wait for result
        update_api_token_last_used_task.delay(auth_token, scope, start_time_iso)

        logger.debug("Scheduled async update for last_used_at (scope: %s, start_time: %s)", scope, start_time_iso)
    except Exception as e:
        # Don't fail the request if task scheduling fails
        logger.warning("Failed to schedule last_used_at update task: %s", e)
class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]

View File

@ -104,6 +104,7 @@ def init_app(app: DifyApp) -> Celery:
"tasks.trigger_processing_tasks", # async trigger processing
"tasks.generate_summary_index_task", # summary index generation
"tasks.regenerate_summary_index_task", # summary index regeneration
"tasks.update_api_token_last_used_task", # async API token last_used_at update
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME

262
api/libs/api_token_cache.py Normal file
View File

@ -0,0 +1,262 @@
"""
API Token Cache Module
Provides Redis-based caching for API token validation to reduce database load.
"""
import json
import logging
from datetime import datetime
from typing import Any
from extensions.ext_redis import redis_client, redis_fallback
logger = logging.getLogger(__name__)
class CachedApiToken:
"""
Simple data class to represent a cached API token.
This is NOT a SQLAlchemy model instance, but a plain Python object
that mimics the ApiToken model interface for read-only access.
"""
def __init__(
self,
id: str,
app_id: str | None,
tenant_id: str | None,
type: str,
token: str,
last_used_at: datetime | None,
created_at: datetime | None,
):
self.id = id
self.app_id = app_id
self.tenant_id = tenant_id
self.type = type
self.token = token
self.last_used_at = last_used_at
self.created_at = created_at
def __repr__(self) -> str:
return f"<CachedApiToken id={self.id} type={self.type}>"
# Cache configuration
CACHE_KEY_PREFIX = "api_token"
CACHE_TTL_SECONDS = 600  # 10 minutes
CACHE_NULL_TTL_SECONDS = 60  # 1 minute for non-existent tokens (prevents cache penetration)


class ApiTokenCache:
    """
    Redis cache wrapper for API tokens.

    Handles serialization, deserialization, and cache invalidation.
    """

    @staticmethod
    def _make_cache_key(token: str, scope: str | None = None) -> str:
        """
        Generate cache key for the given token and scope.

        Args:
            token: The API token string
            scope: The token type/scope (e.g., 'app', 'dataset')

        Returns:
            Cache key string
        """
        scope_str = scope or "any"
        return f"{CACHE_KEY_PREFIX}:{scope_str}:{token}"

    @staticmethod
    def _serialize_token(api_token: Any) -> str:
        """
        Serialize ApiToken object to JSON string.

        Args:
            api_token: ApiToken model instance

        Returns:
            JSON string representation
        """
        data = {
            "id": str(api_token.id),
            "app_id": str(api_token.app_id) if api_token.app_id else None,
            "tenant_id": str(api_token.tenant_id) if api_token.tenant_id else None,
            "type": api_token.type,
            "token": api_token.token,
            "last_used_at": api_token.last_used_at.isoformat() if api_token.last_used_at else None,
            "created_at": api_token.created_at.isoformat() if api_token.created_at else None,
        }
        return json.dumps(data)

    @staticmethod
    def _deserialize_token(cached_data: str) -> Any:
        """
        Deserialize JSON string back to a CachedApiToken object.

        Args:
            cached_data: JSON string from cache

        Returns:
            CachedApiToken instance or None
        """
        if cached_data == "null":
            # Cached null value (token doesn't exist)
            return None
        try:
            data = json.loads(cached_data)
            # Create a simple data object (NOT a SQLAlchemy model instance)
            # This is safe because it's just a plain Python object with attributes
            token_obj = CachedApiToken(
                id=data["id"],
                app_id=data["app_id"],
                tenant_id=data["tenant_id"],
                type=data["type"],
                token=data["token"],
                last_used_at=datetime.fromisoformat(data["last_used_at"]) if data["last_used_at"] else None,
                created_at=datetime.fromisoformat(data["created_at"]) if data["created_at"] else None,
            )
            return token_obj
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            logger.warning("Failed to deserialize token from cache: %s", e)
            return None

    @staticmethod
    @redis_fallback(default_return=None)
    def get(token: str, scope: str | None) -> Any | None:
        """
        Get API token from cache.

        Args:
            token: The API token string
            scope: The token type/scope

        Returns:
            CachedApiToken instance if found in cache, None if not cached or cache miss
        """
        cache_key = ApiTokenCache._make_cache_key(token, scope)
        cached_data = redis_client.get(cache_key)
        if cached_data is None:
            logger.debug("Cache miss for token key: %s", cache_key)
            return None
        # Decode bytes to string
        if isinstance(cached_data, bytes):
            cached_data = cached_data.decode("utf-8")
        logger.debug("Cache hit for token key: %s", cache_key)
        return ApiTokenCache._deserialize_token(cached_data)

    @staticmethod
    @redis_fallback(default_return=False)
    def set(token: str, scope: str | None, api_token: Any | None, ttl: int = CACHE_TTL_SECONDS) -> bool:
        """
        Set API token in cache.

        Args:
            token: The API token string
            scope: The token type/scope
            api_token: ApiToken instance to cache (None for non-existent tokens)
            ttl: Time to live in seconds

        Returns:
            True if successful, False otherwise
        """
        cache_key = ApiTokenCache._make_cache_key(token, scope)
        if api_token is None:
            # Cache null value to prevent cache penetration
            cached_value = "null"
            ttl = CACHE_NULL_TTL_SECONDS
        else:
            cached_value = ApiTokenCache._serialize_token(api_token)
        try:
            redis_client.setex(cache_key, ttl, cached_value)
            logger.debug("Cached token with key: %s, ttl: %ss", cache_key, ttl)
            return True
        except Exception as e:
            logger.warning("Failed to cache token: %s", e)
            return False

    @staticmethod
    @redis_fallback(default_return=False)
    def delete(token: str, scope: str | None = None) -> bool:
        """
        Delete API token from cache.

        Args:
            token: The API token string
            scope: The token type/scope (None to delete all scopes)

        Returns:
            True if successful, False otherwise
        """
        if scope is None:
            # Delete all possible scopes for this token
            # This is a safer approach when scope is unknown
            # NOTE(review): KEYS blocks Redis on large keyspaces; SCAN would be
            # preferable here once callers/mocks no longer depend on keys().
            pattern = f"{CACHE_KEY_PREFIX}:*:{token}"
            try:
                keys = redis_client.keys(pattern)
                if keys:
                    redis_client.delete(*keys)
                    logger.info("Deleted %d cache entries for token", len(keys))
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache with pattern: %s", e)
                return False
        else:
            cache_key = ApiTokenCache._make_cache_key(token, scope)
            try:
                redis_client.delete(cache_key)
                logger.info("Deleted cache for key: %s", cache_key)
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache: %s", e)
                return False

    @staticmethod
    @redis_fallback(default_return=False)
    def invalidate_by_tenant(tenant_id: str) -> bool:
        """
        Invalidate all API token caches for a specific tenant.

        Use this when tenant status changes or tokens are batch updated.

        Args:
            tenant_id: The tenant ID

        Returns:
            True if successful, False otherwise
        """
        # Note: This requires scanning, which can be slow on large Redis instances
        # Consider using a separate index if this becomes a bottleneck
        try:
            pattern = f"{CACHE_KEY_PREFIX}:*"
            cursor = 0
            deleted_count = 0
            while True:
                cursor, keys = redis_client.scan(cursor, match=pattern, count=100)
                for key in keys:
                    # BUG FIX: the previous implementation deleted every
                    # api_token:* key without checking the tenant, wiping the
                    # cache for ALL tenants. Inspect each cached payload and
                    # delete only entries belonging to this tenant.
                    raw = redis_client.get(key)
                    if raw is None:
                        # Entry expired between SCAN and GET - nothing to do.
                        continue
                    if isinstance(raw, bytes):
                        raw = raw.decode("utf-8")
                    if raw == "null":
                        # Cached-miss sentinel carries no tenant info; let it
                        # expire naturally via its short TTL.
                        continue
                    try:
                        data = json.loads(raw)
                    except json.JSONDecodeError:
                        # Unreadable entry - drop it defensively.
                        redis_client.delete(key)
                        deleted_count += 1
                        continue
                    if data.get("tenant_id") == tenant_id:
                        redis_client.delete(key)
                        deleted_count += 1
                if cursor == 0:
                    break
            logger.info("Invalidated %s token cache entries for tenant: %s", deleted_count, tenant_id)
            return True
        except Exception as e:
            logger.warning("Failed to invalidate tenant token cache: %s", e)
            return False

View File

@ -14,6 +14,7 @@ from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.api_token_cache import ApiTokenCache
from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
from models import (
ApiToken,
@ -134,6 +135,12 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(session, api_token_id: str):
# Fetch token details for cache invalidation
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
if token_obj:
# Invalidate cache before deletion
ApiTokenCache.delete(token_obj.token, token_obj.type)
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(

View File

@ -0,0 +1,61 @@
"""
Celery task for updating API token last_used_at timestamp asynchronously.
"""
import logging
from datetime import datetime
from celery import shared_task
from sqlalchemy import update
from sqlalchemy.orm import Session
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import ApiToken
logger = logging.getLogger(__name__)
@shared_task(queue="dataset", bind=True, max_retries=3)
def update_api_token_last_used_task(self, token: str, scope: str | None, start_time_iso: str):
    """
    Asynchronously update the last_used_at timestamp for an API token.

    Uses timestamp comparison to ensure only updates when last_used_at is older
    than the request start time, providing natural concurrency control.

    Args:
        token: The API token string
        scope: The token type/scope (e.g., 'app', 'dataset')
        start_time_iso: ISO format timestamp of when the request started

    Returns:
        A small status dict ("updated" / "no_update_needed" / "failed")
        recorded as the Celery task result.
    """
    # NOTE(review): max_retries=3 is declared, but the broad except below never
    # re-raises or calls self.retry(), so the task is effectively never retried.
    try:
        # Parse start_time from ISO format
        start_time = datetime.fromisoformat(start_time_iso)

        # Update database
        current_time = naive_utc_now()
        with Session(db.engine, expire_on_commit=False) as session:
            # Conditional UPDATE: only touch rows whose last_used_at is unset or
            # older than the request start time, so concurrent tasks cannot
            # overwrite a newer timestamp with an older one.
            update_stmt = (
                update(ApiToken)
                .where(
                    ApiToken.token == token,
                    ApiToken.type == scope,
                    (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < start_time)),
                )
                .values(last_used_at=current_time)
            )
            result = session.execute(update_stmt)
            # hasattr guard: defensive against result objects without rowcount
            if hasattr(result, "rowcount") and result.rowcount > 0:
                session.commit()
                # Log only a token prefix to avoid leaking the full secret
                logger.info("Updated last_used_at for token (async): %s... (scope: %s)", token[:10], scope)
                return {"status": "updated", "rowcount": result.rowcount, "start_time": start_time_iso}
            else:
                logger.debug("No update needed for token: %s... (already up-to-date)", token[:10])
                return {"status": "no_update_needed", "reason": "last_used_at >= start_time"}
    except Exception as e:
        logger.warning("Failed to update last_used_at for token (async): %s", e)
        # Don't retry on failure to avoid blocking the queue
        return {"status": "failed", "error": str(e)}

View File

@ -0,0 +1,245 @@
"""
Unit tests for API Token Cache module.
"""
import json
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from libs.api_token_cache import (
CACHE_KEY_PREFIX,
CACHE_NULL_TTL_SECONDS,
CACHE_TTL_SECONDS,
ApiTokenCache,
CachedApiToken,
)
class TestApiTokenCache:
    """Test cases for ApiTokenCache class."""

    def setup_method(self):
        """Setup test fixtures: a MagicMock standing in for an ApiToken row."""
        self.mock_token = MagicMock()
        self.mock_token.id = "test-token-id-123"
        self.mock_token.app_id = "test-app-id-456"
        self.mock_token.tenant_id = "test-tenant-id-789"
        self.mock_token.type = "app"
        self.mock_token.token = "test-token-value-abc"
        self.mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
        self.mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)

    def test_make_cache_key(self):
        """Test cache key generation (scope defaults to 'any' when None)."""
        # Test with scope
        key = ApiTokenCache._make_cache_key("my-token", "app")
        assert key == f"{CACHE_KEY_PREFIX}:app:my-token"
        # Test without scope
        key = ApiTokenCache._make_cache_key("my-token", None)
        assert key == f"{CACHE_KEY_PREFIX}:any:my-token"

    def test_serialize_token(self):
        """Test token serialization to the JSON payload stored in Redis."""
        serialized = ApiTokenCache._serialize_token(self.mock_token)
        data = json.loads(serialized)
        assert data["id"] == "test-token-id-123"
        assert data["app_id"] == "test-app-id-456"
        assert data["tenant_id"] == "test-tenant-id-789"
        assert data["type"] == "app"
        assert data["token"] == "test-token-value-abc"
        # datetimes are stored as ISO-8601 strings
        assert data["last_used_at"] == "2026-02-03T10:00:00"
        assert data["created_at"] == "2026-01-01T00:00:00"

    def test_serialize_token_with_nulls(self):
        """Test token serialization with None values."""
        mock_token = MagicMock()
        mock_token.id = "test-id"
        mock_token.app_id = None
        mock_token.tenant_id = None
        mock_token.type = "dataset"
        mock_token.token = "test-token"
        mock_token.last_used_at = None
        mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)
        serialized = ApiTokenCache._serialize_token(mock_token)
        data = json.loads(serialized)
        assert data["app_id"] is None
        assert data["tenant_id"] is None
        assert data["last_used_at"] is None

    def test_deserialize_token(self):
        """Test token deserialization back into a CachedApiToken."""
        cached_data = json.dumps(
            {
                "id": "test-id",
                "app_id": "test-app",
                "tenant_id": "test-tenant",
                "type": "app",
                "token": "test-token",
                "last_used_at": "2026-02-03T10:00:00",
                "created_at": "2026-01-01T00:00:00",
            }
        )
        result = ApiTokenCache._deserialize_token(cached_data)
        assert isinstance(result, CachedApiToken)
        assert result.id == "test-id"
        assert result.app_id == "test-app"
        assert result.tenant_id == "test-tenant"
        assert result.type == "app"
        assert result.token == "test-token"
        # ISO strings round-trip back to datetime objects
        assert result.last_used_at == datetime(2026, 2, 3, 10, 0, 0)
        assert result.created_at == datetime(2026, 1, 1, 0, 0, 0)

    def test_deserialize_null_token(self):
        """Test deserialization of null token (cached miss)."""
        result = ApiTokenCache._deserialize_token("null")
        assert result is None

    def test_deserialize_invalid_json(self):
        """Test deserialization with invalid JSON (must not raise)."""
        result = ApiTokenCache._deserialize_token("invalid-json{")
        assert result is None

    @patch("libs.api_token_cache.redis_client")
    def test_get_cache_hit(self, mock_redis):
        """Test cache hit scenario (bytes payload is decoded and parsed)."""
        cached_data = json.dumps(
            {
                "id": "test-id",
                "app_id": "test-app",
                "tenant_id": "test-tenant",
                "type": "app",
                "token": "test-token",
                "last_used_at": "2026-02-03T10:00:00",
                "created_at": "2026-01-01T00:00:00",
            }
        )
        mock_redis.get.return_value = cached_data.encode("utf-8")
        result = ApiTokenCache.get("test-token", "app")
        assert result is not None
        assert isinstance(result, CachedApiToken)
        assert result.app_id == "test-app"
        mock_redis.get.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")

    @patch("libs.api_token_cache.redis_client")
    def test_get_cache_miss(self, mock_redis):
        """Test cache miss scenario."""
        mock_redis.get.return_value = None
        result = ApiTokenCache.get("test-token", "app")
        assert result is None
        mock_redis.get.assert_called_once()

    @patch("libs.api_token_cache.redis_client")
    def test_set_valid_token(self, mock_redis):
        """Test setting a valid token in cache (uses SETEX with default TTL)."""
        result = ApiTokenCache.set("test-token", "app", self.mock_token)
        assert result is True
        mock_redis.setex.assert_called_once()
        args = mock_redis.setex.call_args[0]
        assert args[0] == f"{CACHE_KEY_PREFIX}:app:test-token"
        assert args[1] == CACHE_TTL_SECONDS

    @patch("libs.api_token_cache.redis_client")
    def test_set_null_token(self, mock_redis):
        """Test setting a null token (cache penetration prevention)."""
        result = ApiTokenCache.set("invalid-token", "app", None)
        assert result is True
        mock_redis.setex.assert_called_once()
        args = mock_redis.setex.call_args[0]
        assert args[0] == f"{CACHE_KEY_PREFIX}:app:invalid-token"
        # Null entries get the shorter TTL and the literal "null" sentinel
        assert args[1] == CACHE_NULL_TTL_SECONDS
        assert args[2] == "null"

    @patch("libs.api_token_cache.redis_client")
    def test_delete_with_scope(self, mock_redis):
        """Test deleting token cache with specific scope."""
        result = ApiTokenCache.delete("test-token", "app")
        assert result is True
        mock_redis.delete.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")

    @patch("libs.api_token_cache.redis_client")
    def test_delete_without_scope(self, mock_redis):
        """Test deleting token cache without scope (delete all)."""
        mock_redis.keys.return_value = [
            b"api_token:app:test-token",
            b"api_token:dataset:test-token",
        ]
        result = ApiTokenCache.delete("test-token", None)
        assert result is True
        mock_redis.keys.assert_called_once()
        mock_redis.delete.assert_called_once()

    @patch("libs.api_token_cache.redis_client")
    def test_redis_fallback_on_exception(self, mock_redis):
        """Test Redis fallback when Redis is unavailable."""
        from redis import RedisError

        mock_redis.get.side_effect = RedisError("Connection failed")
        result = ApiTokenCache.get("test-token", "app")
        # Should return None (fallback) instead of raising exception
        assert result is None
class TestApiTokenCacheIntegration:
    """Integration test scenarios."""

    @patch("libs.api_token_cache.redis_client")
    def test_full_cache_lifecycle(self, mock_redis):
        """Test complete cache lifecycle: set -> get -> delete."""
        # Build a stand-in ApiToken row with every attribute the cache serializes.
        token_attrs = {
            "id": "id-123",
            "app_id": "app-456",
            "tenant_id": "tenant-789",
            "type": "app",
            "token": "token-abc",
            "last_used_at": datetime(2026, 2, 3, 10, 0, 0),
            "created_at": datetime(2026, 1, 1, 0, 0, 0),
        }
        fake_token = MagicMock()
        for attr_name, attr_value in token_attrs.items():
            setattr(fake_token, attr_name, attr_value)

        # 1. Set token in cache
        ApiTokenCache.set("token-abc", "app", fake_token)
        assert mock_redis.setex.called

        # 2. Simulate cache hit by feeding the serialized payload back
        payload = ApiTokenCache._serialize_token(fake_token)
        mock_redis.get.return_value = payload.encode("utf-8")
        round_tripped = ApiTokenCache.get("token-abc", "app")
        assert round_tripped is not None
        assert isinstance(round_tripped, CachedApiToken)

        # 3. Delete from cache
        ApiTokenCache.delete("token-abc", "app")
        assert mock_redis.delete.called

    @patch("libs.api_token_cache.redis_client")
    def test_cache_penetration_prevention(self, mock_redis):
        """Test that non-existent tokens are cached as null."""
        # Caching None must write the "null" sentinel with the short TTL
        ApiTokenCache.set("non-existent-token", "app", None)
        _key, ttl, stored_value = mock_redis.setex.call_args[0]
        assert stored_value == "null"
        assert ttl == CACHE_NULL_TTL_SECONDS  # Shorter TTL for null values


if __name__ == "__main__":
    pytest.main([__file__, "-v"])