Skip to content

Commit

Permalink
Export in-memory counter for external use (#39)
Browse files Browse the repository at this point in the history
httprate.NewLocalLimitCounter(windowLength)
  • Loading branch information
VojtechVitek authored Jul 25, 2024
1 parent 05a79e9 commit 9e50ad6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 41 deletions.
14 changes: 3 additions & 11 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ type LimitCounter interface {
}

func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter {
return newRateLimiter(requestLimit, windowLength, options...)
}

func newRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter {
rl := &rateLimiter{
requestLimit: requestLimit,
windowLength: windowLength,
Expand All @@ -43,14 +39,10 @@ func newRateLimiter(requestLimit int, windowLength time.Duration, options ...Opt
}

if rl.limitCounter == nil {
rl.limitCounter = &localCounter{
latestWindow: time.Now().UTC().Truncate(windowLength),
latestCounters: make(map[uint64]int),
previousCounters: make(map[uint64]int),
windowLength: windowLength,
}
rl.limitCounter = NewLocalLimitCounter(windowLength)
} else {
rl.limitCounter.Config(requestLimit, windowLength)
}
rl.limitCounter.Config(requestLimit, windowLength)

if rl.onRequestLimit == nil {
rl.onRequestLimit = func(w http.ResponseWriter, r *http.Request) {
Expand Down
38 changes: 23 additions & 15 deletions local_counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,27 @@ import (
"github.com/cespare/xxhash/v2"
)

var _ LimitCounter = &localCounter{}
// NewLocalLimitCounter creates an instance of localCounter,
// which is an in-memory implementation of http.LimitCounter.
func NewLocalLimitCounter(windowLength time.Duration) *localCounter {
return &localCounter{
windowLength: windowLength,
latestWindow: time.Now().UTC().Truncate(windowLength),
latestCounters: make(map[uint64]int),
previousCounters: make(map[uint64]int),
}
}

var _ LimitCounter = (*localCounter)(nil)

type localCounter struct {
windowLength time.Duration
latestWindow time.Time
previousCounters map[uint64]int
latestCounters map[uint64]int
windowLength time.Duration
previousCounters map[uint64]int
mu sync.RWMutex
}

func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()

c.windowLength = windowLength
}

func (c *localCounter) Increment(key string, currentWindow time.Time) error {
return c.IncrementBy(key, currentWindow, 1)
}

func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount int) error {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down Expand Up @@ -60,6 +60,14 @@ func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time)
return 0, 0, nil
}

// Config implements LimitCounter but is redundant.
func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {}

// Increment implements LimitCounter but is redundant.
func (c *localCounter) Increment(key string, currentWindow time.Time) error {
return c.IncrementBy(key, currentWindow, 1)
}

func (c *localCounter) evict(currentWindow time.Time) {
if c.latestWindow == currentWindow {
return
Expand All @@ -73,7 +81,7 @@ func (c *localCounter) evict(currentWindow time.Time) {
}

c.latestWindow = currentWindow
// NOTE: Don't use clear() to keep backward-compatibility.
// NOTE: Don't use clear() to be compatible with older version of Go.
c.previousCounters, c.latestCounters = make(map[uint64]int), make(map[uint64]int)
}

Expand Down
19 changes: 4 additions & 15 deletions local_counter_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package httprate
package httprate_test

import (
"fmt"
Expand All @@ -7,18 +7,12 @@ import (
"testing"
"time"

"github.com/go-chi/httprate"
"golang.org/x/sync/errgroup"
)

func TestLocalCounter(t *testing.T) {
limitCounter := &localCounter{
latestWindow: time.Now().UTC().Truncate(time.Minute),
latestCounters: make(map[uint64]int),
previousCounters: make(map[uint64]int),
windowLength: time.Minute,
}

limitCounter.Config(1000, time.Minute)
limitCounter := httprate.NewLocalLimitCounter(time.Minute)

currentWindow := time.Now().UTC().Truncate(time.Minute)
previousWindow := currentWindow.Add(-time.Minute)
Expand Down Expand Up @@ -146,12 +140,7 @@ func TestLocalCounter(t *testing.T) {
}

func BenchmarkLocalCounter(b *testing.B) {
limitCounter := &localCounter{
latestWindow: time.Now().UTC().Truncate(time.Minute),
latestCounters: make(map[uint64]int),
previousCounters: make(map[uint64]int),
windowLength: time.Minute,
}
limitCounter := httprate.NewLocalLimitCounter(time.Minute)

currentWindow := time.Now().UTC().Truncate(time.Minute)
previousWindow := currentWindow.Add(-time.Minute)
Expand Down

0 comments on commit 9e50ad6

Please sign in to comment.