Files
ragflow/internal/cache/redis.go
Jin Hai 70e9743ef1 RAGFlow go API server (#13240)
# RAGFlow Go Implementation Plan 🚀

This repository tracks the progress of porting RAGFlow to Go. We'll
implement core features and provide performance comparisons between
Python and Go versions.

## Implementation Checklist

- [x] User Management APIs
- [x] Dataset Management Operations
- [x] Retrieval Test
- [x] Chat Management Operations
- [x] Infinity Go SDK

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Yingfeng Zhang <yingfeng.zhang@gmail.com>
2026-03-04 19:17:16 +08:00

997 lines
24 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package cache
import (
"context"
"encoding/json"
"fmt"
"math"
"math/rand"
"strconv"
"sync"
"time"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"ragflow/internal/logger"
"ragflow/internal/server"
)
var (
globalClient *RedisClient
once sync.Once
)
// RedisClient wraps go-redis client with additional utility methods
type RedisClient struct {
client *redis.Client
luaDeleteIfEqual *redis.Script
luaTokenBucket *redis.Script
luaAutoIncrement *redis.Script
config *server.RedisConfig
}
// RedisMsg represents a message from Redis Stream
type RedisMsg struct {
consumer *redis.Client
queueName string
groupName string
msgID string
message map[string]interface{}
}
// Lua scripts
const (
luaDeleteIfEqualScript = `
local current_value = redis.call('get', KEYS[1])
if current_value and current_value == ARGV[1] then
redis.call('del', KEYS[1])
return 1
end
return 0
`
luaTokenBucketScript = `
local key = KEYS[1]
local capacity = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local cost = tonumber(ARGV[4])
local data = redis.call("HMGET", key, "tokens", "timestamp")
local tokens = tonumber(data[1])
local last_ts = tonumber(data[2])
if tokens == nil then
tokens = capacity
last_ts = now
end
local delta = math.max(0, now - last_ts)
tokens = math.min(capacity, tokens + delta * rate)
if tokens < cost then
return {0, tokens}
end
tokens = tokens - cost
redis.call("HMSET", key,
"tokens", tokens,
"timestamp", now
)
redis.call("EXPIRE", key, math.ceil(capacity / rate * 2))
return {1, tokens}
`
)
// Init initializes Redis client
func Init(cfg *server.RedisConfig) error {
var initErr error
once.Do(func() {
if cfg.Host == "" {
logger.Info("Redis host not configured, skipping Redis initialization")
return
}
client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Password: cfg.Password,
DB: cfg.DB,
})
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), server.DefaultConnectTimeout)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
initErr = fmt.Errorf("failed to connect to Redis: %w", err)
return
}
globalClient = &RedisClient{
client: client,
config: cfg,
luaDeleteIfEqual: redis.NewScript(luaDeleteIfEqualScript),
luaTokenBucket: redis.NewScript(luaTokenBucketScript),
}
logger.Info("Redis client initialized",
zap.String("host", cfg.Host),
zap.Int("port", cfg.Port),
zap.Int("db", cfg.DB),
)
})
return initErr
}
// Get gets global Redis client instance
func Get() *RedisClient {
return globalClient
}
// Close closes Redis client
func Close() error {
if globalClient != nil && globalClient.client != nil {
return globalClient.client.Close()
}
return nil
}
// IsEnabled checks if Redis is enabled (configured and initialized)
func IsEnabled() bool {
return globalClient != nil && globalClient.client != nil
}
// Health checks if Redis is healthy
func (r *RedisClient) Health() bool {
if r.client == nil {
return false
}
ctx := context.Background()
if err := r.client.Ping(ctx).Err(); err != nil {
return false
}
testKey := "health_check_" + uuid.New().String()
testValue := "yy"
if err := r.client.Set(ctx, testKey, testValue, 3*time.Second).Err(); err != nil {
return false
}
val, err := r.client.Get(ctx, testKey).Result()
if err != nil || val != testValue {
return false
}
return true
}
// Info returns Redis server information
func (r *RedisClient) Info() map[string]interface{} {
if r.client == nil {
return nil
}
ctx := context.Background()
infoStr, err := r.client.Info(ctx).Result()
if err != nil {
logger.Warn("Failed to get Redis info", zap.Error(err))
return nil
}
// Parse info string to map
info := make(map[string]string)
lines := splitLines(infoStr)
for _, line := range lines {
if line == "" || line[0] == '#' {
continue
}
parts := splitN(line, ":", 2)
if len(parts) == 2 {
info[parts[0]] = parts[1]
}
}
result := map[string]interface{}{
"redis_version": info["redis_version"],
"server_mode": getServerMode(info),
"used_memory": info["used_memory_human"],
"total_system_memory": info["total_system_memory_human"],
"mem_fragmentation_ratio": info["mem_fragmentation_ratio"],
"connected_clients": parseInt(info["connected_clients"]),
"blocked_clients": parseInt(info["blocked_clients"]),
"instantaneous_ops_per_sec": parseInt(info["instantaneous_ops_per_sec"]),
"total_commands_processed": parseInt(info["total_commands_processed"]),
}
return result
}
func getServerMode(info map[string]string) string {
if mode, ok := info["server_mode"]; ok {
return mode
}
return info["redis_mode"]
}
func splitLines(s string) []string {
var lines []string
start := 0
for i := 0; i < len(s); i++ {
if s[i] == '\n' {
lines = append(lines, s[start:i])
start = i + 1
}
}
if start < len(s) {
lines = append(lines, s[start:])
}
return lines
}
func splitN(s, sep string, n int) []string {
if n <= 0 {
return []string{s}
}
idx := -1
for i := 0; i < len(s)-len(sep)+1; i++ {
if s[i:i+len(sep)] == sep {
idx = i
break
}
}
if idx == -1 {
return []string{s}
}
return []string{s[:idx], s[idx+len(sep):]}
}
func parseInt(s string) int {
v, _ := strconv.Atoi(s)
return v
}
// IsAlive checks if Redis client is alive
func (r *RedisClient) IsAlive() bool {
return r.client != nil
}
// Exist checks if key exists
func (r *RedisClient) Exist(key string) (bool, error) {
if r.client == nil {
return false, nil
}
ctx := context.Background()
exists, err := r.client.Exists(ctx, key).Result()
if err != nil {
logger.Warn("Redis Exist error", zap.String("key", key), zap.Error(err))
return false, err
}
return exists > 0, nil
}
// Get gets value by key
func (r *RedisClient) Get(key string) (string, error) {
if r.client == nil {
return "", nil
}
ctx := context.Background()
val, err := r.client.Get(ctx, key).Result()
if err == redis.Nil {
return "", nil
}
if err != nil {
logger.Warn("Redis Get error", zap.String("key", key), zap.Error(err))
return "", err
}
return val, nil
}
// SetObj sets object with JSON serialization
func (r *RedisClient) SetObj(key string, obj interface{}, exp time.Duration) bool {
if r.client == nil {
return false
}
ctx := context.Background()
data, err := json.Marshal(obj)
if err != nil {
logger.Warn("Redis SetObj marshal error", zap.String("key", key), zap.Error(err))
return false
}
if err := r.client.Set(ctx, key, data, exp).Err(); err != nil {
logger.Warn("Redis SetObj error", zap.String("key", key), zap.Error(err))
return false
}
return true
}
// GetObj gets and unmarshals object from Redis
func (r *RedisClient) GetObj(key string, dest interface{}) bool {
if r.client == nil {
return false
}
ctx := context.Background()
data, err := r.client.Get(ctx, key).Result()
if err == redis.Nil {
return false
}
if err != nil {
logger.Warn("Redis GetObj error", zap.String("key", key), zap.Error(err))
return false
}
if err := json.Unmarshal([]byte(data), dest); err != nil {
logger.Warn("Redis GetObj unmarshal error", zap.String("key", key), zap.Error(err))
return false
}
return true
}
// Set sets value with expiration
func (r *RedisClient) Set(key string, value string, exp time.Duration) bool {
if r.client == nil {
return false
}
ctx := context.Background()
if err := r.client.Set(ctx, key, value, exp).Err(); err != nil {
logger.Warn("Redis Set error", zap.String("key", key), zap.Error(err))
return false
}
return true
}
// SetNX sets value only if key does not exist
func (r *RedisClient) SetNX(key string, value string, exp time.Duration) bool {
if r.client == nil {
return false
}
ctx := context.Background()
ok, err := r.client.SetNX(ctx, key, value, exp).Result()
if err != nil {
logger.Warn("Redis SetNX error", zap.String("key", key), zap.Error(err))
return false
}
return ok
}
// GetOrCreateSecretKey atomically retrieves an existing key or creates a new one
// Uses Redis SETNX command to ensure atomicity across multiple goroutines/processes
func (r *RedisClient) GetOrCreateKey(key string, value string) (string, error) {
if r.client == nil {
return "", nil
}
ctx := context.Background()
// First, try to get the existing key
existingKey, err := r.client.Get(ctx, key).Result()
if err == nil {
logger.Warn("Redis Get error", zap.String("key", key), zap.Error(err))
// Successfully retrieved existing key
return existingKey, nil
}
// Use SETNX to atomically set the key only if it doesn't exist
// SETNX returns true if the key was set, false if it already existed
success, err := r.client.SetNX(ctx, key, value, 0).Result()
if err != nil {
return "", fmt.Errorf("failed to set key in Redis: %v", err)
}
if success {
// This goroutine successfully set the key
return value, nil
}
// SETNX failed, meaning another goroutine set the key concurrently
// Retrieve and return that key
finalKey, err := r.client.Get(ctx, key).Result()
if err != nil {
return "", fmt.Errorf("failed to get key set by another process: %v", err)
}
return finalKey, nil
}
// SAdd adds member to set
func (r *RedisClient) SAdd(key string, member string) bool {
if r.client == nil {
return false
}
ctx := context.Background()
if err := r.client.SAdd(ctx, key, member).Err(); err != nil {
logger.Warn("Redis SAdd error", zap.String("key", key), zap.Error(err))
return false
}
return true
}
// SRem removes member from set
func (r *RedisClient) SRem(key string, member string) bool {
if r.client == nil {
return false
}
ctx := context.Background()
if err := r.client.SRem(ctx, key, member).Err(); err != nil {
logger.Warn("Redis SRem error", zap.String("key", key), zap.Error(err))
return false
}
return true
}
// SMembers returns all members of a set
func (r *RedisClient) SMembers(key string) ([]string, error) {
if r.client == nil {
return nil, nil
}
ctx := context.Background()
members, err := r.client.SMembers(ctx, key).Result()
if err != nil {
logger.Warn("Redis SMembers error", zap.String("key", key), zap.Error(err))
return nil, err
}
return members, nil
}
// SIsMember checks if member exists in set
func (r *RedisClient) SIsMember(key string, member string) bool {
if r.client == nil {
return false
}
ctx := context.Background()
ok, err := r.client.SIsMember(ctx, key, member).Result()
if err != nil {
logger.Warn("Redis SIsMember error", zap.String("key", key), zap.Error(err))
return false
}
return ok
}
// ZAdd adds member with score to sorted set
func (r *RedisClient) ZAdd(key string, member string, score float64) bool {
if r.client == nil {
return false
}
ctx := context.Background()
if err := r.client.ZAdd(ctx, key, redis.Z{Score: score, Member: member}).Err(); err != nil {
logger.Warn("Redis ZAdd error", zap.String("key", key), zap.Error(err))
return false
}
return true
}
// ZCount returns count of members with score in range
func (r *RedisClient) ZCount(key string, min, max float64) int64 {
if r.client == nil {
return 0
}
ctx := context.Background()
count, err := r.client.ZCount(ctx, key, fmt.Sprintf("%f", min), fmt.Sprintf("%f", max)).Result()
if err != nil {
logger.Warn("Redis ZCount error", zap.String("key", key), zap.Error(err))
return 0
}
return count
}
// ZPopMin pops minimum score members from sorted set
func (r *RedisClient) ZPopMin(key string, count int) ([]redis.Z, error) {
if r.client == nil {
return nil, nil
}
ctx := context.Background()
members, err := r.client.ZPopMin(ctx, key, int64(count)).Result()
if err != nil {
logger.Warn("Redis ZPopMin error", zap.String("key", key), zap.Error(err))
return nil, err
}
return members, nil
}
// ZRangeByScore returns members with score in range
func (r *RedisClient) ZRangeByScore(key string, min, max float64) ([]string, error) {
if r.client == nil {
return nil, nil
}
ctx := context.Background()
members, err := r.client.ZRangeByScore(ctx, key, &redis.ZRangeBy{
Min: fmt.Sprintf("%f", min),
Max: fmt.Sprintf("%f", max),
}).Result()
if err != nil {
logger.Warn("Redis ZRangeByScore error", zap.String("key", key), zap.Error(err))
return nil, err
}
return members, nil
}
// ZRemRangeByScore removes members with score in range
func (r *RedisClient) ZRemRangeByScore(key string, min, max float64) int64 {
if r.client == nil {
return 0
}
ctx := context.Background()
count, err := r.client.ZRemRangeByScore(ctx, key, fmt.Sprintf("%f", min), fmt.Sprintf("%f", max)).Result()
if err != nil {
logger.Warn("Redis ZRemRangeByScore error", zap.String("key", key), zap.Error(err))
return 0
}
return count
}
// IncrBy increments key by increment
func (r *RedisClient) IncrBy(key string, increment int64) (int64, error) {
if r.client == nil {
return 0, nil
}
ctx := context.Background()
val, err := r.client.IncrBy(ctx, key, increment).Result()
if err != nil {
logger.Warn("Redis IncrBy error", zap.String("key", key), zap.Error(err))
return 0, err
}
return val, nil
}
// DecrBy decrements key by decrement
func (r *RedisClient) DecrBy(key string, decrement int64) (int64, error) {
if r.client == nil {
return 0, nil
}
ctx := context.Background()
val, err := r.client.DecrBy(ctx, key, decrement).Result()
if err != nil {
logger.Warn("Redis DecrBy error", zap.String("key", key), zap.Error(err))
return 0, err
}
return val, nil
}
// GenerateAutoIncrementID generates auto-increment ID
func (r *RedisClient) GenerateAutoIncrementID(keyPrefix string, namespace string, increment int64, ensureMinimum *int64) int64 {
if r.client == nil {
return -1
}
if keyPrefix == "" {
keyPrefix = "id_generator"
}
if namespace == "" {
namespace = "default"
}
if increment == 0 {
increment = 1
}
redisKey := fmt.Sprintf("%s:%s", keyPrefix, namespace)
ctx := context.Background()
// Check if key exists
exists, err := r.client.Exists(ctx, redisKey).Result()
if err != nil {
logger.Warn("Redis GenerateAutoIncrementID error", zap.Error(err))
return -1
}
if exists == 0 && ensureMinimum != nil {
startID := int64(math.Max(1, float64(*ensureMinimum)))
r.client.Set(ctx, redisKey, startID, 0)
return startID
}
// Get current value
if ensureMinimum != nil {
current, err := r.client.Get(ctx, redisKey).Int64()
if err == nil && current < *ensureMinimum {
r.client.Set(ctx, redisKey, *ensureMinimum, 0)
return *ensureMinimum
}
}
// Increment
nextID, err := r.client.IncrBy(ctx, redisKey, increment).Result()
if err != nil {
logger.Warn("Redis GenerateAutoIncrementID increment error", zap.Error(err))
return -1
}
return nextID
}
// Transaction sets key with NX flag (transaction-like behavior)
func (r *RedisClient) Transaction(key string, value string, exp time.Duration) bool {
if r.client == nil {
return false
}
ctx := context.Background()
pipe := r.client.Pipeline()
pipe.SetNX(ctx, key, value, exp)
_, err := pipe.Exec(ctx)
if err != nil {
logger.Warn("Redis Transaction error", zap.String("key", key), zap.Error(err))
return false
}
return true
}
// QueueProduct produces a message to Redis Stream
func (r *RedisClient) QueueProduct(queue string, message interface{}) bool {
if r.client == nil {
return false
}
ctx := context.Background()
for i := 0; i < 3; i++ {
data, err := json.Marshal(message)
if err != nil {
logger.Warn("Redis QueueProduct marshal error", zap.Error(err))
return false
}
_, err = r.client.XAdd(ctx, &redis.XAddArgs{
Stream: queue,
Values: map[string]interface{}{"message": string(data)},
}).Result()
if err == nil {
return true
}
logger.Warn("Redis QueueProduct error", zap.String("queue", queue), zap.Error(err))
time.Sleep(100 * time.Millisecond)
}
return false
}
// QueueConsumer consumes a message from Redis Stream
func (r *RedisClient) QueueConsumer(queueName, groupName, consumerName string, msgID string) (*RedisMsg, error) {
if r.client == nil {
return nil, nil
}
ctx := context.Background()
for i := 0; i < 3; i++ {
// Create consumer group if not exists
groups, err := r.client.XInfoGroups(ctx, queueName).Result()
if err != nil && err.Error() != "no such key" {
logger.Warn("Redis QueueConsumer XInfoGroups error", zap.Error(err))
}
groupExists := false
for _, g := range groups {
if g.Name == groupName {
groupExists = true
break
}
}
if !groupExists {
err = r.client.XGroupCreateMkStream(ctx, queueName, groupName, "0").Err()
if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
logger.Warn("Redis QueueConsumer XGroupCreate error", zap.Error(err))
}
}
if msgID == "" {
msgID = ">"
}
messages, err := r.client.XReadGroup(ctx, &redis.XReadGroupArgs{
Group: groupName,
Consumer: consumerName,
Streams: []string{queueName, msgID},
Count: 1,
Block: 5 * time.Second,
}).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
logger.Warn("Redis QueueConsumer XReadGroup error", zap.Error(err))
time.Sleep(100 * time.Millisecond)
continue
}
if len(messages) == 0 || len(messages[0].Messages) == 0 {
return nil, nil
}
msg := messages[0].Messages[0]
var messageData map[string]interface{}
if msgStr, ok := msg.Values["message"].(string); ok {
json.Unmarshal([]byte(msgStr), &messageData)
}
return &RedisMsg{
consumer: r.client,
queueName: queueName,
groupName: groupName,
msgID: msg.ID,
message: messageData,
}, nil
}
return nil, nil
}
// Ack acknowledges the message
func (m *RedisMsg) Ack() bool {
if m.consumer == nil {
return false
}
ctx := context.Background()
err := m.consumer.XAck(ctx, m.queueName, m.groupName, m.msgID).Err()
if err != nil {
logger.Warn("RedisMsg Ack error", zap.Error(err))
return false
}
return true
}
// GetMessage returns the message data
func (m *RedisMsg) GetMessage() map[string]interface{} {
return m.message
}
// GetMsgID returns the message ID
func (m *RedisMsg) GetMsgID() string {
return m.msgID
}
// GetPendingMsg gets pending messages
func (r *RedisClient) GetPendingMsg(queue, groupName string) ([]redis.XPendingExt, error) {
if r.client == nil {
return nil, nil
}
ctx := context.Background()
msgs, err := r.client.XPendingExt(ctx, &redis.XPendingExtArgs{
Stream: queue,
Group: groupName,
Start: "-",
End: "+",
Count: 10,
}).Result()
if err != nil {
if err.Error() != "No such key" {
logger.Warn("Redis GetPendingMsg error", zap.Error(err))
}
return nil, err
}
return msgs, nil
}
// RequeueMsg requeues a message
func (r *RedisClient) RequeueMsg(queue, groupName, msgID string) {
if r.client == nil {
return
}
ctx := context.Background()
for i := 0; i < 3; i++ {
msgs, err := r.client.XRange(ctx, queue, msgID, msgID).Result()
if err != nil {
logger.Warn("Redis RequeueMsg XRange error", zap.Error(err))
time.Sleep(100 * time.Millisecond)
continue
}
if len(msgs) == 0 {
return
}
r.client.XAdd(ctx, &redis.XAddArgs{
Stream: queue,
Values: msgs[0].Values,
})
r.client.XAck(ctx, queue, groupName, msgID)
return
}
}
// QueueInfo returns queue group info
func (r *RedisClient) QueueInfo(queue, groupName string) (map[string]interface{}, error) {
if r.client == nil {
return nil, nil
}
ctx := context.Background()
for i := 0; i < 3; i++ {
groups, err := r.client.XInfoGroups(ctx, queue).Result()
if err != nil {
logger.Warn("Redis QueueInfo error", zap.Error(err))
time.Sleep(100 * time.Millisecond)
continue
}
for _, g := range groups {
if g.Name == groupName {
return map[string]interface{}{
"name": g.Name,
"consumers": g.Consumers,
"pending": g.Pending,
"last_delivered": g.LastDeliveredID,
}, nil
}
}
return nil, nil
}
return nil, nil
}
// DeleteIfEqual deletes key if its value equals expected value (atomic)
func (r *RedisClient) DeleteIfEqual(key, expectedValue string) bool {
if r.client == nil {
return false
}
ctx := context.Background()
result, err := r.luaDeleteIfEqual.Run(ctx, r.client, []string{key}, expectedValue).Result()
if err != nil {
logger.Warn("Redis DeleteIfEqual error", zap.Error(err))
return false
}
return result.(int64) == 1
}
// Delete deletes a key
func (r *RedisClient) Delete(key string) bool {
if r.client == nil {
return false
}
ctx := context.Background()
if err := r.client.Del(ctx, key).Err(); err != nil {
logger.Warn("Redis Delete error", zap.String("key", key), zap.Error(err))
return false
}
return true
}
// Expire sets expiration on a key
func (r *RedisClient) Expire(key string, exp time.Duration) bool {
if r.client == nil {
return false
}
ctx := context.Background()
if err := r.client.Expire(ctx, key, exp).Err(); err != nil {
logger.Warn("Redis Expire error", zap.String("key", key), zap.Error(err))
return false
}
return true
}
// TTL gets remaining time to live of a key
func (r *RedisClient) TTL(key string) time.Duration {
if r.client == nil {
return -2
}
ctx := context.Background()
ttl, err := r.client.TTL(ctx, key).Result()
if err != nil {
logger.Warn("Redis TTL error", zap.String("key", key), zap.Error(err))
return -2
}
return ttl
}
// DistributedLock distributed lock implementation
type DistributedLock struct {
client *RedisClient
lockKey string
lockValue string
timeout time.Duration
blockingTimeout time.Duration
}
// NewDistributedLock creates a new distributed lock
func NewDistributedLock(lockKey string, lockValue string, timeout time.Duration, blockingTimeout time.Duration) *DistributedLock {
if globalClient == nil {
return nil
}
if lockValue == "" {
lockValue = uuid.New().String()
}
return &DistributedLock{
client: globalClient,
lockKey: lockKey,
lockValue: lockValue,
timeout: timeout,
blockingTimeout: blockingTimeout,
}
}
// Acquire acquires the lock
func (l *DistributedLock) Acquire() bool {
if l.client == nil {
return false
}
// Delete if stale
l.client.DeleteIfEqual(l.lockKey, l.lockValue)
return l.client.SetNX(l.lockKey, l.lockValue, l.timeout)
}
// SpinAcquire keeps trying to acquire the lock
func (l *DistributedLock) SpinAcquire(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
l.client.DeleteIfEqual(l.lockKey, l.lockValue)
if l.client.SetNX(l.lockKey, l.lockValue, l.timeout) {
return nil
}
time.Sleep(10 * time.Second)
}
}
}
// Release releases the lock
func (l *DistributedLock) Release() bool {
if l.client == nil {
return false
}
return l.client.DeleteIfEqual(l.lockKey, l.lockValue)
}
// TokenBucket token bucket rate limiter
type TokenBucket struct {
client *RedisClient
key string
capacity float64
rate float64
}
// NewTokenBucket creates a new token bucket
func NewTokenBucket(key string, capacity, rate float64) *TokenBucket {
if globalClient == nil {
return nil
}
return &TokenBucket{
client: globalClient,
key: key,
capacity: capacity,
rate: rate,
}
}
// Allow checks if request is allowed
func (tb *TokenBucket) Allow(cost float64) (bool, float64) {
if tb.client == nil || tb.client.client == nil {
return true, 0
}
ctx := context.Background()
now := float64(time.Now().Unix())
result, err := tb.client.luaTokenBucket.Run(ctx, tb.client.client, []string{tb.key},
tb.capacity, tb.rate, now, cost).Result()
if err != nil {
logger.Warn("TokenBucket Allow error", zap.Error(err))
return true, 0
}
values := result.([]interface{})
allowed := values[0].(int64) == 1
tokens := values[1].(int64)
return allowed, float64(tokens)
}
// GetClient returns the underlying go-redis client for advanced usage
func (r *RedisClient) GetClient() *redis.Client {
return r.client
}
// RandomSleep sleeps for random duration between min and max milliseconds
func RandomSleep(minMs, maxMs int) {
duration := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond
time.Sleep(duration)
}