Skip to content

Commit

Permalink
[FIX] token bucket fill (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas-ggd committed Jun 29, 2024
1 parent f0c24f4 commit f87c1d8
Showing 1 changed file with 31 additions and 38 deletions.
69 changes: 31 additions & 38 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
b64 "encoding/base64"
"errors"
"log"
"math"
"net/http"
"os"
"sync"
Expand All @@ -15,7 +16,10 @@ import (
)

// define constant variable of keyPrefix to avoid duplicate key in Redis
const keyPrefix = "ls_prefix:"
const (
keyPrefix = "ls_prefix:"
lastRefillPrefix = "_lastRefillTime"
)

// RateLimiter is struct based on Redis
type RateLimiter struct {
Expand All @@ -29,7 +33,7 @@ type RateLimiter struct {
currentToken int64

// lastRefillTime represents time that this bucket fill operation was tried
lastRefillTime time.Time
refillInterval time.Duration

mutex sync.Mutex

Expand All @@ -46,12 +50,12 @@ func encodeKey(value string) string {
}

// NewRateLimiter to received and define new RateLimiter struct
func NewRateLimiter(client *redis.Client, rate, maxToken int64) *RateLimiter {
func NewRateLimiter(client *redis.Client, rate, maxToken int64, refillInterval time.Duration) *RateLimiter {
return &RateLimiter{
client: client,
rate: rate,
maxTokens: maxToken,
lastRefillTime: time.Now(),
refillInterval: refillInterval,
currentToken: maxToken,
logger: log.New(os.Stdout, "RateLimiter: ", log.Lmicroseconds),
}
Expand All @@ -67,23 +71,24 @@ func NewRateLimiter(client *redis.Client, rate, maxToken int64) *RateLimiter {
// Returns:
//
// bool: Returns true if the request is allowed, false otherwise.
func (rl *RateLimiter) IsRequestAllowed(key string, tokens int64) bool {
func (rl *RateLimiter) IsRequestAllowed(key string, token int64) bool {
// use mutex to avoid race condition
rl.mutex.Lock()
defer rl.mutex.Unlock()

// encode key
sEnc := keyPrefix + encodeKey(key)

// get current token count from Redis
tokenCount, err := rl.client.Get(context.Background(), sEnc).Int64()
if err != nil && !errors.Is(err, redis.Nil) {
rl.logger.Printf("Error getting token count from Redis: %v", err)
return false
}

// get last refill time from Redis
lastRefillTimeStr, err := rl.client.Get(context.Background(), sEnc+"_lastRefillTime").Result()
if errors.Is(err, redis.Nil) {
tokenCount = rl.maxTokens
}

lastRefillTimeStr, err := rl.client.Get(context.Background(), sEnc+lastRefillPrefix).Result()
var lastRefillTime time.Time
if err == nil {
lastRefillTime, err = time.Parse(time.RFC3339, lastRefillTimeStr)
Expand All @@ -94,27 +99,20 @@ func (rl *RateLimiter) IsRequestAllowed(key string, tokens int64) bool {
} else if !errors.Is(err, redis.Nil) {
rl.logger.Printf("Error getting last refill time from Redis: %v", err)
return false
} else {
lastRefillTime = time.Now()
}

// refill tokens
tokenCount, lastRefillTime = rl.refillBucket(lastRefillTime, tokenCount)

// update last refill time in Redis
rl.client.Set(context.Background(), sEnc+"_lastRefillTime", lastRefillTime.Format(time.RFC3339), 0)
tokenCount = rl.refill(tokenCount, lastRefillTime)

// check if enough tokens are available
if tokenCount > 0 {
// decrement token count
tokenCount--
// update token count in Redis
err = rl.client.Set(context.Background(), sEnc, tokenCount, 0).Err()
if err != nil {
rl.logger.Printf("Error setting token count in Redis: %v", err)
return false
}
if tokenCount >= token {
tokenCount -= token
rl.client.Set(context.Background(), sEnc, tokenCount, 0)
rl.client.Set(context.Background(), sEnc+lastRefillPrefix, time.Now().Format(time.RFC3339), 0)
return true
}

rl.client.Set(context.Background(), sEnc+lastRefillPrefix, time.Now().Format(time.RFC3339), 0)
return false
}

Expand All @@ -128,11 +126,11 @@ func (rl *RateLimiter) IsRequestAllowed(key string, tokens int64) bool {
// Returns:
//
// gin.HandlerFunc: A Gin handler function that can be used as middleware in the Gin router.
func RateLimiterMiddleware(limiter *RateLimiter, tokens int64) gin.HandlerFunc {
func RateLimiterMiddleware(limiter *RateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()

if !limiter.IsRequestAllowed(ip, tokens) {
token := int64(1)
if !limiter.IsRequestAllowed(ip, token) {
limiter.logger.Printf("Rate limit exceeded for IP: %s", ip)
c.Header("X-RateLimit-Remaining", "0")
c.JSON(http.StatusTooManyRequests, gin.H{"error": "too many requests"})
Expand All @@ -144,18 +142,13 @@ func RateLimiterMiddleware(limiter *RateLimiter, tokens int64) gin.HandlerFunc {
}
}

// refillBucket function calculate time, when token bucket can refill
func (rl *RateLimiter) refillBucket(lastRefillTime time.Time, tokenCount int64) (int64, time.Time) {
func (rl *RateLimiter) refill(currentTokens int64, lastRefillTime time.Time) int64 {
now := time.Now()
duration := now.Sub(lastRefillTime)
elapsed := now.Sub(lastRefillTime)

// Calculate tokens to add based on elapsed time and rate
tokensToAdd := (duration.Nanoseconds() * rl.rate) / 1e9 // maybe this calculation isn't correct, but i try to avoid float64, because sometimes it not accuracy

tokenCount = tokenCount + tokensToAdd
if tokenCount > rl.maxTokens {
tokenCount = rl.maxTokens
}
// calculate time which each token needs to refill in token bucket
tokensToAdd := elapsed.Nanoseconds() / rl.refillInterval.Nanoseconds()
newTokens := int64(math.Min(float64(currentTokens+tokensToAdd), float64(rl.maxTokens)))

return tokenCount, now
return newTokens
}

0 comments on commit f87c1d8

Please sign in to comment.