mirror of
https://github.com/langgenius/dify.git
synced 2026-02-22 19:15:47 +08:00
Modify to synchronize redis data to db regularly.
This commit is contained in:
@ -122,7 +122,7 @@ These commands assume you start from the repository root.
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token_update,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
```
|
||||
|
||||
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
|
||||
|
||||
@ -1155,6 +1155,16 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
# API token last_used_at batch update
|
||||
ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field(
|
||||
description="Enable periodic batch update of API token last_used_at timestamps",
|
||||
default=True,
|
||||
)
|
||||
API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field(
|
||||
description="Interval in minutes for batch updating API token last_used_at (default 30)",
|
||||
default=30,
|
||||
)
|
||||
|
||||
# Trigger provider refresh (simple version)
|
||||
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
|
||||
description="Enable trigger provider refresh poller",
|
||||
|
||||
@ -17,7 +17,6 @@ from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.api_token_cache import ApiTokenCache
|
||||
from libs.api_token_updater import update_token_last_used_at
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
@ -321,8 +320,8 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
cached_token = ApiTokenCache.get(auth_token, scope)
|
||||
if cached_token is not None:
|
||||
logger.debug("Token validation served from cache for scope: %s", scope)
|
||||
# Asynchronously update last_used_at (non-blocking)
|
||||
_async_update_token_last_used_at(auth_token, scope)
|
||||
# Record usage in Redis for later batch update (no Celery task per request)
|
||||
_record_token_usage(auth_token, scope)
|
||||
return cached_token
|
||||
|
||||
# Cache miss - use Redis lock for single-flight mode
|
||||
@ -332,14 +331,14 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
|
||||
def _query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
|
||||
"""
|
||||
Query API token from database, update last_used_at, and cache the result.
|
||||
Query API token from database and cache the result.
|
||||
|
||||
last_used_at is NOT updated here -- it is handled by the periodic batch
|
||||
task via _record_token_usage().
|
||||
|
||||
Raises Unauthorized if token is invalid.
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
current_time = naive_utc_now()
|
||||
update_token_last_used_at(auth_token, scope, current_time, session=session)
|
||||
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
api_token = session.scalar(stmt)
|
||||
|
||||
@ -348,6 +347,8 @@ def _query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
ApiTokenCache.set(auth_token, scope, api_token)
|
||||
# Record usage for later batch update
|
||||
_record_token_usage(auth_token, scope)
|
||||
return api_token
|
||||
|
||||
|
||||
@ -386,27 +387,19 @@ def _fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiTo
|
||||
return _query_token_from_db(auth_token, scope)
|
||||
|
||||
|
||||
def _async_update_token_last_used_at(auth_token: str, scope: str | None):
|
||||
def _record_token_usage(auth_token: str, scope: str | None):
|
||||
"""
|
||||
Asynchronously update the last_used_at timestamp for a token.
|
||||
Record token usage in Redis for later batch update by a scheduled job.
|
||||
|
||||
This schedules a Celery task to update the database without blocking
|
||||
the current request. The update time is passed to ensure only older
|
||||
records are updated, providing natural concurrency control.
|
||||
Instead of dispatching a Celery task per request, we simply SET a key in Redis.
|
||||
A Celery Beat scheduled task will periodically scan these keys and batch-update
|
||||
last_used_at in the database.
|
||||
"""
|
||||
try:
|
||||
from tasks.update_api_token_last_used_task import update_api_token_last_used_task
|
||||
|
||||
# Record the update time for concurrency control
|
||||
update_time = naive_utc_now()
|
||||
update_time_iso = update_time.isoformat()
|
||||
|
||||
# Fire and forget - don't wait for result
|
||||
update_api_token_last_used_task.delay(auth_token, scope, update_time_iso)
|
||||
logger.debug("Scheduled async update for last_used_at (scope: %s, update_time: %s)", scope, update_time_iso)
|
||||
key = f"api_token_active:{scope}:{auth_token}"
|
||||
redis_client.set(key, naive_utc_now().isoformat(), ex=3600) # TTL 1 hour as safety net
|
||||
except Exception as e:
|
||||
# Don't fail the request if task scheduling fails
|
||||
logger.warning("Failed to schedule last_used_at update task: %s", e)
|
||||
logger.warning("Failed to record token usage: %s", e)
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
|
||||
@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
|
||||
if [[ -z "${CELERY_QUEUES}" ]]; then
|
||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||
# Cloud edition: separate queues for dataset and trigger tasks
|
||||
DEFAULT_QUEUES="api_token_update,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
else
|
||||
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
|
||||
DEFAULT_QUEUES="api_token_update,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
fi
|
||||
else
|
||||
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
||||
|
||||
@ -104,7 +104,6 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"tasks.trigger_processing_tasks", # async trigger processing
|
||||
"tasks.generate_summary_index_task", # summary index generation
|
||||
"tasks.regenerate_summary_index_task", # summary index regeneration
|
||||
"tasks.update_api_token_last_used_task", # async API token last_used_at update
|
||||
]
|
||||
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
|
||||
|
||||
@ -185,6 +184,14 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
|
||||
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
|
||||
}
|
||||
|
||||
if dify_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK:
|
||||
imports.append("schedule.update_api_token_last_used_task")
|
||||
beat_schedule["batch_update_api_token_last_used"] = {
|
||||
"task": "schedule.update_api_token_last_used_task.batch_update_api_token_last_used",
|
||||
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
|
||||
}
|
||||
|
||||
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
|
||||
|
||||
return celery_app
|
||||
|
||||
@ -1,76 +0,0 @@
|
||||
"""
|
||||
Unified API Token update utilities.
|
||||
|
||||
This module provides a centralized method for updating API token last_used_at
|
||||
to avoid code duplication between sync and async update paths.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import ApiToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def update_token_last_used_at(
|
||||
token: str, scope: str | None, update_time: datetime, session: Session | None = None
|
||||
) -> dict:
|
||||
"""
|
||||
Unified method to update API token last_used_at timestamp.
|
||||
|
||||
This method is used by both:
|
||||
1. Direct database update (cache miss scenario)
|
||||
2. Async Celery task (cache hit scenario)
|
||||
|
||||
Args:
|
||||
token: The API token string
|
||||
scope: The token type/scope (e.g., 'app', 'dataset')
|
||||
update_time: The time to use for the update (for concurrency control)
|
||||
session: Optional existing session to use (if None, creates new one)
|
||||
|
||||
Returns:
|
||||
Dict with status, rowcount, and other metadata
|
||||
"""
|
||||
current_time = naive_utc_now()
|
||||
|
||||
def _do_update(s: Session) -> dict:
|
||||
"""Execute the update within the session."""
|
||||
update_stmt = (
|
||||
update(ApiToken)
|
||||
.where(
|
||||
ApiToken.token == token,
|
||||
ApiToken.type == scope,
|
||||
# Only update if last_used_at is older than update_time
|
||||
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < update_time)),
|
||||
)
|
||||
.values(last_used_at=current_time)
|
||||
)
|
||||
result = s.execute(update_stmt)
|
||||
|
||||
rowcount = getattr(result, "rowcount", 0)
|
||||
if rowcount > 0:
|
||||
s.commit()
|
||||
logger.debug("Updated last_used_at for token: %s... (scope: %s)", token[:10], scope)
|
||||
return {"status": "updated", "rowcount": rowcount}
|
||||
else:
|
||||
logger.debug("No update needed for token: %s... (already up-to-date)", token[:10])
|
||||
return {"status": "no_update_needed", "reason": "last_used_at >= update_time"}
|
||||
|
||||
try:
|
||||
if session:
|
||||
# Use provided session (sync path)
|
||||
return _do_update(session)
|
||||
else:
|
||||
# Create new session (async path)
|
||||
with Session(db.engine, expire_on_commit=False) as new_session:
|
||||
return _do_update(new_session)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to update last_used_at for token: %s", e)
|
||||
return {"status": "failed", "error": str(e)}
|
||||
102
api/schedule/update_api_token_last_used_task.py
Normal file
102
api/schedule/update_api_token_last_used_task.py
Normal file
@ -0,0 +1,102 @@
|
||||
"""
|
||||
Scheduled task to batch-update API token last_used_at timestamps.
|
||||
|
||||
Instead of updating the database on every request, token usage is recorded
|
||||
in Redis as lightweight SET keys (api_token_active:{scope}:{token}).
|
||||
This task runs periodically (default every 30 minutes) to flush those
|
||||
records into the database in a single batch operation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import ApiToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ACTIVE_TOKEN_KEY_PREFIX = "api_token_active:"
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
def batch_update_api_token_last_used():
|
||||
"""
|
||||
Batch update last_used_at for all recently active API tokens.
|
||||
|
||||
Scans Redis for api_token_active:* keys, parses the token and scope
|
||||
from each key, and performs a batch database update.
|
||||
"""
|
||||
click.echo(click.style("batch_update_api_token_last_used: start.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
updated_count = 0
|
||||
scanned_count = 0
|
||||
current_time = naive_utc_now()
|
||||
|
||||
try:
|
||||
# Collect all active token keys
|
||||
keys_to_process: list[str] = []
|
||||
for key in redis_client.scan_iter(match=f"{ACTIVE_TOKEN_KEY_PREFIX}*", count=200):
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode("utf-8")
|
||||
keys_to_process.append(key)
|
||||
scanned_count += 1
|
||||
|
||||
if not keys_to_process:
|
||||
click.echo(click.style("batch_update_api_token_last_used: no active tokens found.", fg="yellow"))
|
||||
return
|
||||
|
||||
# Parse token info from keys: api_token_active:{scope}:{token}
|
||||
token_scope_pairs: list[tuple[str, str | None]] = []
|
||||
for key in keys_to_process:
|
||||
# Strip prefix
|
||||
suffix = key[len(ACTIVE_TOKEN_KEY_PREFIX):]
|
||||
# Split into scope:token (scope may be "None")
|
||||
parts = suffix.split(":", 1)
|
||||
if len(parts) == 2:
|
||||
scope_str, token = parts
|
||||
scope = None if scope_str == "None" else scope_str
|
||||
token_scope_pairs.append((token, scope))
|
||||
|
||||
# Batch update in database
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
for token, scope in token_scope_pairs:
|
||||
stmt = (
|
||||
update(ApiToken)
|
||||
.where(
|
||||
ApiToken.token == token,
|
||||
ApiToken.type == scope,
|
||||
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < current_time)),
|
||||
)
|
||||
.values(last_used_at=current_time)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
rowcount = getattr(result, "rowcount", 0)
|
||||
if rowcount > 0:
|
||||
updated_count += 1
|
||||
|
||||
if updated_count > 0:
|
||||
session.commit()
|
||||
|
||||
# Delete processed keys from Redis
|
||||
if keys_to_process:
|
||||
redis_client.delete(*[k.encode("utf-8") if isinstance(k, str) else k for k in keys_to_process])
|
||||
|
||||
except Exception:
|
||||
logger.exception("batch_update_api_token_last_used failed")
|
||||
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"batch_update_api_token_last_used: done. "
|
||||
f"scanned={scanned_count}, updated={updated_count}, elapsed={elapsed:.2f}s",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
@ -1,59 +0,0 @@
|
||||
"""
|
||||
Celery task for updating API token last_used_at timestamp asynchronously.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from celery import shared_task
|
||||
|
||||
from libs.api_token_updater import update_token_last_used_at
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="api_token_update", bind=True)
|
||||
def update_api_token_last_used_task(self, token: str, scope: str | None, update_time_iso: str):
|
||||
"""
|
||||
Asynchronously update the last_used_at timestamp for an API token.
|
||||
|
||||
Uses the unified update_token_last_used_at() method to avoid code duplication.
|
||||
|
||||
Queue: api_token_update (dedicated queue to isolate from other tasks and
|
||||
prevent accumulation in production environment)
|
||||
|
||||
Args:
|
||||
token: The API token string
|
||||
scope: The token type/scope (e.g., 'app', 'dataset')
|
||||
update_time_iso: ISO format timestamp for the update operation
|
||||
|
||||
Returns:
|
||||
Dict with status and metadata
|
||||
|
||||
Raises:
|
||||
Exception: Re-raises exceptions to allow Celery retry mechanism and monitoring
|
||||
"""
|
||||
try:
|
||||
# Parse update_time from ISO format
|
||||
update_time = datetime.fromisoformat(update_time_iso)
|
||||
|
||||
# Use unified update method
|
||||
result = update_token_last_used_at(token, scope, update_time, session=None)
|
||||
|
||||
if result["status"] == "updated":
|
||||
logger.info("Updated last_used_at for token (async): %s... (scope: %s)", token[:10], scope)
|
||||
elif result["status"] == "failed":
|
||||
# If update failed, log and raise for retry
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
logger.error("Failed to update last_used_at for token (async): %s", error_msg)
|
||||
raise Exception(f"Token update failed: {error_msg}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
# Log the error with full context (logger.exception includes traceback automatically)
|
||||
logger.exception("Error in update_api_token_last_used_task (token: %s..., scope: %s)", token[:10], scope)
|
||||
|
||||
# Raise exception to let Celery handle retry and monitoring
|
||||
# This allows Flower and other monitoring tools to track failures
|
||||
raise
|
||||
@ -1,21 +1,19 @@
|
||||
"""
|
||||
Integration tests for API Token Cache with Redis and Celery.
|
||||
Integration tests for API Token Cache with Redis.
|
||||
|
||||
These tests require:
|
||||
- Redis server running
|
||||
- Test database configured
|
||||
- Celery worker running (for full integration test)
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.api_token_cache import ApiTokenCache, CachedApiToken
|
||||
from libs.api_token_updater import update_token_last_used_at
|
||||
from models.model import ApiToken
|
||||
|
||||
|
||||
@ -38,18 +36,14 @@ class TestApiTokenCacheRedisIntegration:
|
||||
def _cleanup(self):
|
||||
"""Remove test data from Redis."""
|
||||
try:
|
||||
# Delete test cache key
|
||||
redis_client.delete(self.cache_key)
|
||||
# Delete any test tenant index
|
||||
redis_client.delete("tenant_tokens:test-tenant-id")
|
||||
# Delete any test locks
|
||||
redis_client.delete(f"api_token_last_used_lock:{self.test_scope}:{self.test_token}")
|
||||
redis_client.delete(f"api_token_active:{self.test_scope}:{self.test_token}")
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
def test_cache_set_and_get_with_real_redis(self):
|
||||
"""Test cache set and get operations with real Redis."""
|
||||
# Create a mock token
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_token = MagicMock()
|
||||
@ -92,30 +86,24 @@ class TestApiTokenCacheRedisIntegration:
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
# Set in cache
|
||||
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
|
||||
|
||||
# Check TTL
|
||||
ttl = redis_client.ttl(self.cache_key)
|
||||
assert 595 <= ttl <= 600 # Should be around 600 seconds (10 minutes)
|
||||
|
||||
def test_cache_null_value_for_invalid_token(self):
|
||||
"""Test caching null value for invalid tokens"""
|
||||
# Cache null value
|
||||
"""Test caching null value for invalid tokens."""
|
||||
result = ApiTokenCache.set(self.test_token, self.test_scope, None)
|
||||
assert result is True
|
||||
|
||||
# Verify in Redis
|
||||
cached_data = redis_client.get(self.cache_key)
|
||||
assert cached_data == b"null"
|
||||
|
||||
# Get from cache should return None
|
||||
cached_token = ApiTokenCache.get(self.test_token, self.test_scope)
|
||||
assert cached_token is None
|
||||
|
||||
# Check TTL is shorter for null values
|
||||
ttl = redis_client.ttl(self.cache_key)
|
||||
assert 55 <= ttl <= 60 # Should be around 60 seconds
|
||||
assert 55 <= ttl <= 60
|
||||
|
||||
def test_cache_delete_with_real_redis(self):
|
||||
"""Test cache deletion with real Redis."""
|
||||
@ -130,15 +118,11 @@ class TestApiTokenCacheRedisIntegration:
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
# Set in cache
|
||||
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
|
||||
assert redis_client.exists(self.cache_key) == 1
|
||||
|
||||
# Delete from cache
|
||||
result = ApiTokenCache.delete(self.test_token, self.test_scope)
|
||||
assert result is True
|
||||
|
||||
# Verify deleted
|
||||
assert redis_client.exists(self.cache_key) == 0
|
||||
|
||||
def test_tenant_index_creation(self):
|
||||
@ -155,14 +139,11 @@ class TestApiTokenCacheRedisIntegration:
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
# Set in cache
|
||||
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
|
||||
|
||||
# Verify tenant index exists
|
||||
index_key = f"tenant_tokens:{tenant_id}"
|
||||
assert redis_client.exists(index_key) == 1
|
||||
|
||||
# Verify cache key is in the index
|
||||
members = redis_client.smembers(index_key)
|
||||
cache_keys = [m.decode("utf-8") if isinstance(m, bytes) else m for m in members]
|
||||
assert self.cache_key in cache_keys
|
||||
@ -173,7 +154,6 @@ class TestApiTokenCacheRedisIntegration:
|
||||
|
||||
tenant_id = "test-tenant-id"
|
||||
|
||||
# Create multiple tokens for the same tenant
|
||||
for i in range(3):
|
||||
token_value = f"test-token-{i}"
|
||||
mock_token = MagicMock()
|
||||
@ -187,21 +167,17 @@ class TestApiTokenCacheRedisIntegration:
|
||||
|
||||
ApiTokenCache.set(token_value, "app", mock_token)
|
||||
|
||||
# Verify all cached
|
||||
for i in range(3):
|
||||
key = f"api_token:app:test-token-{i}"
|
||||
assert redis_client.exists(key) == 1
|
||||
|
||||
# Invalidate by tenant
|
||||
result = ApiTokenCache.invalidate_by_tenant(tenant_id)
|
||||
assert result is True
|
||||
|
||||
# Verify all deleted
|
||||
for i in range(3):
|
||||
key = f"api_token:app:test-token-{i}"
|
||||
assert redis_client.exists(key) == 0
|
||||
|
||||
# Verify index also deleted
|
||||
assert redis_client.exists(f"tenant_tokens:{tenant_id}") == 0
|
||||
|
||||
def test_concurrent_cache_access(self):
|
||||
@ -218,151 +194,72 @@ class TestApiTokenCacheRedisIntegration:
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
# Set once
|
||||
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
|
||||
|
||||
# Concurrent reads
|
||||
def get_from_cache():
|
||||
return ApiTokenCache.get(self.test_token, self.test_scope)
|
||||
|
||||
# Execute 50 concurrent reads
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(get_from_cache) for _ in range(50)]
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
|
||||
# All should succeed
|
||||
assert len(results) == 50
|
||||
assert all(r is not None for r in results)
|
||||
assert all(isinstance(r, CachedApiToken) for r in results)
|
||||
|
||||
|
||||
class TestApiTokenUpdaterIntegration:
|
||||
"""Integration tests for unified token updater."""
|
||||
class TestTokenUsageRecording:
|
||||
"""Tests for recording token usage in Redis (batch update approach)."""
|
||||
|
||||
@pytest.mark.usefixtures("db_session")
|
||||
def test_update_token_last_used_at_with_session(self, db_session):
|
||||
"""Test unified update method with provided session."""
|
||||
# Create a test token in database
|
||||
test_token = ApiToken()
|
||||
test_token.id = "test-updater-id"
|
||||
test_token.token = "test-updater-token"
|
||||
test_token.type = "app"
|
||||
test_token.app_id = "test-app"
|
||||
test_token.tenant_id = "test-tenant"
|
||||
test_token.last_used_at = datetime.now() - timedelta(minutes=10)
|
||||
test_token.created_at = datetime.now() - timedelta(days=30)
|
||||
|
||||
db_session.add(test_token)
|
||||
db_session.commit()
|
||||
def setup_method(self):
|
||||
self.test_token = "test-usage-token"
|
||||
self.test_scope = "app"
|
||||
self.active_key = f"api_token_active:{self.test_scope}:{self.test_token}"
|
||||
|
||||
def teardown_method(self):
|
||||
try:
|
||||
# Update using unified method
|
||||
start_time = datetime.now()
|
||||
result = update_token_last_used_at(test_token.token, test_token.type, start_time, session=db_session)
|
||||
redis_client.delete(self.active_key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Verify result
|
||||
assert result["status"] == "updated"
|
||||
assert result["rowcount"] == 1
|
||||
def test_record_token_usage_sets_redis_key(self):
|
||||
"""Test that _record_token_usage writes an active key to Redis."""
|
||||
from controllers.service_api.wraps import _record_token_usage
|
||||
|
||||
# Verify in database
|
||||
db_session.refresh(test_token)
|
||||
assert test_token.last_used_at >= start_time
|
||||
_record_token_usage(self.test_token, self.test_scope)
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
db_session.delete(test_token)
|
||||
db_session.commit()
|
||||
# Key should exist
|
||||
assert redis_client.exists(self.active_key) == 1
|
||||
|
||||
# Value should be an ISO timestamp
|
||||
value = redis_client.get(self.active_key)
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("utf-8")
|
||||
datetime.fromisoformat(value) # Should not raise
|
||||
|
||||
@pytest.mark.celery_integration
|
||||
class TestCeleryTaskIntegration:
|
||||
"""
|
||||
Integration tests for Celery task.
|
||||
def test_record_token_usage_has_ttl(self):
|
||||
"""Test that active keys have a TTL as safety net."""
|
||||
from controllers.service_api.wraps import _record_token_usage
|
||||
|
||||
Requires Celery worker running with api_token_update queue.
|
||||
Run with: pytest -m celery_integration
|
||||
"""
|
||||
_record_token_usage(self.test_token, self.test_scope)
|
||||
|
||||
@pytest.mark.usefixtures("db_session")
|
||||
def test_celery_task_execution(self, db_session):
|
||||
"""Test Celery task can be executed successfully."""
|
||||
from tasks.update_api_token_last_used_task import update_api_token_last_used_task
|
||||
ttl = redis_client.ttl(self.active_key)
|
||||
assert 3595 <= ttl <= 3600 # ~1 hour
|
||||
|
||||
# Create a test token in database
|
||||
test_token = ApiToken()
|
||||
test_token.id = "test-celery-id"
|
||||
test_token.token = "test-celery-token"
|
||||
test_token.type = "app"
|
||||
test_token.app_id = "test-app"
|
||||
test_token.tenant_id = "test-tenant"
|
||||
test_token.last_used_at = datetime.now() - timedelta(minutes=10)
|
||||
test_token.created_at = datetime.now() - timedelta(days=30)
|
||||
def test_record_token_usage_overwrites(self):
|
||||
"""Test that repeated calls overwrite the same key (no accumulation)."""
|
||||
from controllers.service_api.wraps import _record_token_usage
|
||||
|
||||
db_session.add(test_token)
|
||||
db_session.commit()
|
||||
_record_token_usage(self.test_token, self.test_scope)
|
||||
first_value = redis_client.get(self.active_key)
|
||||
|
||||
try:
|
||||
# Send task
|
||||
start_time_iso = datetime.now().isoformat()
|
||||
result = update_api_token_last_used_task.delay(test_token.token, test_token.type, start_time_iso)
|
||||
time.sleep(0.01) # Tiny delay so timestamp differs
|
||||
|
||||
# Wait for task to complete (with timeout)
|
||||
task_result = result.get(timeout=10)
|
||||
_record_token_usage(self.test_token, self.test_scope)
|
||||
second_value = redis_client.get(self.active_key)
|
||||
|
||||
# Verify task executed
|
||||
assert task_result["status"] in ["updated", "no_update_needed"]
|
||||
|
||||
# Verify in database
|
||||
db_session.refresh(test_token)
|
||||
# last_used_at should be updated or already recent
|
||||
assert test_token.last_used_at is not None
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
db_session.delete(test_token)
|
||||
db_session.commit()
|
||||
|
||||
@pytest.mark.usefixtures("db_session")
|
||||
def test_concurrent_celery_tasks_with_redis_lock(self, db_session):
|
||||
"""Test multiple Celery tasks with Redis lock (防抖)."""
|
||||
from tasks.update_api_token_last_used_task import update_api_token_last_used_task
|
||||
|
||||
# Create a test token
|
||||
test_token = ApiToken()
|
||||
test_token.id = "test-concurrent-id"
|
||||
test_token.token = "test-concurrent-token"
|
||||
test_token.type = "app"
|
||||
test_token.app_id = "test-app"
|
||||
test_token.tenant_id = "test-tenant"
|
||||
test_token.last_used_at = datetime.now() - timedelta(minutes=10)
|
||||
test_token.created_at = datetime.now() - timedelta(days=30)
|
||||
|
||||
db_session.add(test_token)
|
||||
db_session.commit()
|
||||
|
||||
try:
|
||||
# Send 10 tasks concurrently
|
||||
start_time_iso = datetime.now().isoformat()
|
||||
tasks = []
|
||||
for _ in range(10):
|
||||
result = update_api_token_last_used_task.delay(test_token.token, test_token.type, start_time_iso)
|
||||
tasks.append(result)
|
||||
|
||||
# Wait for all tasks
|
||||
results = [task.get(timeout=15) for task in tasks]
|
||||
|
||||
# Count how many actually updated
|
||||
updated_count = sum(1 for r in results if r["status"] == "updated")
|
||||
skipped_count = sum(1 for r in results if r["status"] == "skipped")
|
||||
|
||||
# Due to Redis lock, most should be skipped
|
||||
assert skipped_count >= 8 # At least 8 out of 10 should be skipped
|
||||
assert updated_count <= 2 # At most 2 should actually update
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
db_session.delete(test_token)
|
||||
db_session.commit()
|
||||
# Key count should still be 1 (overwritten, not accumulated)
|
||||
assert redis_client.exists(self.active_key) == 1
|
||||
|
||||
|
||||
class TestEndToEndCacheFlow:
|
||||
@ -376,11 +273,9 @@ class TestEndToEndCacheFlow:
|
||||
2. Second request (cache hit) -> return from cache
|
||||
3. Verify Redis state
|
||||
"""
|
||||
|
||||
test_token_value = "test-e2e-token"
|
||||
test_scope = "app"
|
||||
|
||||
# Create test token in DB
|
||||
test_token = ApiToken()
|
||||
test_token.id = "test-e2e-id"
|
||||
test_token.token = test_token_value
|
||||
@ -397,7 +292,6 @@ class TestEndToEndCacheFlow:
|
||||
# Step 1: Cache miss - set token in cache
|
||||
ApiTokenCache.set(test_token_value, test_scope, test_token)
|
||||
|
||||
# Verify cached
|
||||
cache_key = f"api_token:{test_scope}:{test_token_value}"
|
||||
assert redis_client.exists(cache_key) == 1
|
||||
|
||||
@ -415,11 +309,9 @@ class TestEndToEndCacheFlow:
|
||||
# Step 4: Delete and verify cleanup
|
||||
ApiTokenCache.delete(test_token_value, test_scope)
|
||||
assert redis_client.exists(cache_key) == 0
|
||||
# Index should be cleaned up
|
||||
assert cache_key.encode() not in redis_client.smembers(index_key)
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
db_session.delete(test_token)
|
||||
db_session.commit()
|
||||
redis_client.delete(f"api_token:{test_scope}:{test_token_value}")
|
||||
@ -433,7 +325,6 @@ class TestEndToEndCacheFlow:
|
||||
test_token_value = "test-concurrent-token"
|
||||
test_scope = "app"
|
||||
|
||||
# Setup cache
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = "concurrent-id"
|
||||
mock_token.app_id = "test-app"
|
||||
@ -446,7 +337,6 @@ class TestEndToEndCacheFlow:
|
||||
ApiTokenCache.set(test_token_value, test_scope, mock_token)
|
||||
|
||||
try:
|
||||
# Simulate 100 concurrent cache reads
|
||||
def read_cache():
|
||||
return ApiTokenCache.get(test_token_value, test_scope)
|
||||
|
||||
@ -456,18 +346,12 @@ class TestEndToEndCacheFlow:
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# All should succeed
|
||||
assert len(results) == 100
|
||||
assert all(r is not None for r in results)
|
||||
|
||||
# Should be fast (< 1 second for 100 reads)
|
||||
assert elapsed < 1.0, f"Too slow: {elapsed}s for 100 cache reads"
|
||||
|
||||
print(f"\n✓ 100 concurrent cache reads in {elapsed:.3f}s")
|
||||
print(f"✓ Average: {(elapsed / 100) * 1000:.2f}ms per read")
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
ApiTokenCache.delete(test_token_value, test_scope)
|
||||
redis_client.delete(f"tenant_tokens:{mock_token.tenant_id}")
|
||||
|
||||
@ -480,29 +364,22 @@ class TestRedisFailover:
|
||||
"""Test system degrades gracefully when Redis is unavailable."""
|
||||
from redis import RedisError
|
||||
|
||||
# Simulate Redis failure
|
||||
mock_redis.get.side_effect = RedisError("Connection failed")
|
||||
mock_redis.setex.side_effect = RedisError("Connection failed")
|
||||
|
||||
# Cache operations should not raise exceptions
|
||||
result_get = ApiTokenCache.get("test-token", "app")
|
||||
assert result_get is None # Returns None (fallback)
|
||||
assert result_get is None
|
||||
|
||||
result_set = ApiTokenCache.set("test-token", "app", None)
|
||||
assert result_set is False # Returns False (fallback)
|
||||
|
||||
# Application should continue working (using database directly)
|
||||
assert result_set is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run integration tests
|
||||
pytest.main(
|
||||
[
|
||||
__file__,
|
||||
"-v",
|
||||
"-s",
|
||||
"--tb=short",
|
||||
"-m",
|
||||
"not celery_integration", # Skip Celery tests by default
|
||||
]
|
||||
)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
Unit tests for API Token Cache module.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -42,10 +43,8 @@ class TestApiTokenCache:
|
||||
|
||||
def test_serialize_token(self):
|
||||
"""Test token serialization."""
|
||||
import orjson
|
||||
|
||||
serialized = ApiTokenCache._serialize_token(self.mock_token)
|
||||
data = orjson.loads(serialized) # orjson to parse bytes
|
||||
data = json.loads(serialized)
|
||||
|
||||
assert data["id"] == "test-token-id-123"
|
||||
assert data["app_id"] == "test-app-id-456"
|
||||
@ -57,8 +56,6 @@ class TestApiTokenCache:
|
||||
|
||||
def test_serialize_token_with_nulls(self):
|
||||
"""Test token serialization with None values."""
|
||||
import orjson
|
||||
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = "test-id"
|
||||
mock_token.app_id = None
|
||||
@ -69,7 +66,7 @@ class TestApiTokenCache:
|
||||
mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)
|
||||
|
||||
serialized = ApiTokenCache._serialize_token(mock_token)
|
||||
data = orjson.loads(serialized) # orjson to parse bytes
|
||||
data = json.loads(serialized)
|
||||
|
||||
assert data["app_id"] is None
|
||||
assert data["tenant_id"] is None
|
||||
@ -77,9 +74,7 @@ class TestApiTokenCache:
|
||||
|
||||
def test_deserialize_token(self):
|
||||
"""Test token deserialization."""
|
||||
import orjson
|
||||
|
||||
cached_data = orjson.dumps(
|
||||
cached_data = json.dumps(
|
||||
{
|
||||
"id": "test-id",
|
||||
"app_id": "test-app",
|
||||
@ -115,9 +110,7 @@ class TestApiTokenCache:
|
||||
@patch("libs.api_token_cache.redis_client")
|
||||
def test_get_cache_hit(self, mock_redis):
|
||||
"""Test cache hit scenario."""
|
||||
import orjson
|
||||
|
||||
cached_data = orjson.dumps(
|
||||
cached_data = json.dumps(
|
||||
{
|
||||
"id": "test-id",
|
||||
"app_id": "test-app",
|
||||
@ -127,8 +120,8 @@ class TestApiTokenCache:
|
||||
"last_used_at": "2026-02-03T10:00:00",
|
||||
"created_at": "2026-01-01T00:00:00",
|
||||
}
|
||||
)
|
||||
mock_redis.get.return_value = cached_data # orjson returns bytes
|
||||
).encode("utf-8")
|
||||
mock_redis.get.return_value = cached_data
|
||||
|
||||
result = ApiTokenCache.get("test-token", "app")
|
||||
|
||||
@ -168,7 +161,7 @@ class TestApiTokenCache:
|
||||
args = mock_redis.setex.call_args[0]
|
||||
assert args[0] == f"{CACHE_KEY_PREFIX}:app:invalid-token"
|
||||
assert args[1] == CACHE_NULL_TTL_SECONDS
|
||||
assert args[2] == b"null" # orjson returns bytes
|
||||
assert args[2] == b"null"
|
||||
|
||||
@patch("libs.api_token_cache.redis_client")
|
||||
def test_delete_with_scope(self, mock_redis):
|
||||
@ -238,7 +231,7 @@ class TestApiTokenCacheIntegration:
|
||||
|
||||
# 2. Simulate cache hit
|
||||
cached_data = ApiTokenCache._serialize_token(mock_token)
|
||||
mock_redis.get.return_value = cached_data # Already bytes from orjson
|
||||
mock_redis.get.return_value = cached_data # bytes from model_dump_json().encode()
|
||||
|
||||
retrieved = ApiTokenCache.get("token-abc", "app")
|
||||
assert retrieved is not None
|
||||
@ -255,7 +248,7 @@ class TestApiTokenCacheIntegration:
|
||||
ApiTokenCache.set("non-existent-token", "app", None)
|
||||
|
||||
args = mock_redis.setex.call_args[0]
|
||||
assert args[2] == b"null" # orjson returns bytes
|
||||
assert args[2] == b"null"
|
||||
assert args[1] == CACHE_NULL_TTL_SECONDS # Shorter TTL for null values
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user