diff --git a/limiter.go b/limiter.go index 12dad3d..324fa59 100644 --- a/limiter.go +++ b/limiter.go @@ -16,10 +16,6 @@ type LimitCounter interface { } func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter { - return newRateLimiter(requestLimit, windowLength, options...) -} - -func newRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter { rl := &rateLimiter{ requestLimit: requestLimit, windowLength: windowLength, @@ -43,14 +39,10 @@ func newRateLimiter(requestLimit int, windowLength time.Duration, options ...Opt } if rl.limitCounter == nil { - rl.limitCounter = &localCounter{ - latestWindow: time.Now().UTC().Truncate(windowLength), - latestCounters: make(map[uint64]int), - previousCounters: make(map[uint64]int), - windowLength: windowLength, - } + rl.limitCounter = NewLocalLimitCounter(windowLength) + } else { + rl.limitCounter.Config(requestLimit, windowLength) } - rl.limitCounter.Config(requestLimit, windowLength) if rl.onRequestLimit == nil { rl.onRequestLimit = func(w http.ResponseWriter, r *http.Request) { diff --git a/local_counter.go b/local_counter.go index 87a6ca4..d03cd55 100644 --- a/local_counter.go +++ b/local_counter.go @@ -7,27 +7,27 @@ import ( "github.com/cespare/xxhash/v2" ) -var _ LimitCounter = &localCounter{} +// NewLocalLimitCounter creates an instance of localCounter, +// which is an in-memory implementation of http.LimitCounter. +func NewLocalLimitCounter(windowLength time.Duration) *localCounter { + return &localCounter{ + windowLength: windowLength, + latestWindow: time.Now().UTC().Truncate(windowLength), + latestCounters: make(map[uint64]int), + previousCounters: make(map[uint64]int), + } +} + +var _ LimitCounter = (*localCounter)(nil) type localCounter struct { + windowLength time.Duration latestWindow time.Time - previousCounters map[uint64]int latestCounters map[uint64]int - windowLength time.Duration + previousCounters map[uint64]int mu sync.RWMutex } -func (c *localCounter) Config(requestLimit int, windowLength time.Duration) { - c.mu.Lock() - defer c.mu.Unlock() - - c.windowLength = windowLength -} - -func (c *localCounter) Increment(key string, currentWindow time.Time) error { - return c.IncrementBy(key, currentWindow, 1) -} - func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount int) error { c.mu.Lock() defer c.mu.Unlock() @@ -60,6 +60,14 @@ func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) return 0, 0, nil } +// Config implements LimitCounter but is redundant. +func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {} + +// Increment implements LimitCounter but is redundant. +func (c *localCounter) Increment(key string, currentWindow time.Time) error { + return c.IncrementBy(key, currentWindow, 1) +} + func (c *localCounter) evict(currentWindow time.Time) { if c.latestWindow == currentWindow { return @@ -73,7 +81,7 @@ func (c *localCounter) evict(currentWindow time.Time) { } c.latestWindow = currentWindow - // NOTE: Don't use clear() to keep backward-compatibility. + // NOTE: Don't use clear() to be compatible with older version of Go. c.previousCounters, c.latestCounters = make(map[uint64]int), make(map[uint64]int) } diff --git a/local_counter_test.go b/local_counter_test.go index 7f6fd32..b3ee478 100644 --- a/local_counter_test.go +++ b/local_counter_test.go @@ -1,4 +1,4 @@ -package httprate +package httprate_test import ( "fmt" @@ -7,18 +7,12 @@ import ( "testing" "time" + "github.com/go-chi/httprate" "golang.org/x/sync/errgroup" ) func TestLocalCounter(t *testing.T) { - limitCounter := &localCounter{ - latestWindow: time.Now().UTC().Truncate(time.Minute), - latestCounters: make(map[uint64]int), - previousCounters: make(map[uint64]int), - windowLength: time.Minute, - } - - limitCounter.Config(1000, time.Minute) + limitCounter := httprate.NewLocalLimitCounter(time.Minute) currentWindow := time.Now().UTC().Truncate(time.Minute) previousWindow := currentWindow.Add(-time.Minute) @@ -146,12 +140,7 @@ func TestLocalCounter(t *testing.T) { } func BenchmarkLocalCounter(b *testing.B) { - limitCounter := &localCounter{ - latestWindow: time.Now().UTC().Truncate(time.Minute), - latestCounters: make(map[uint64]int), - previousCounters: make(map[uint64]int), - windowLength: time.Minute, - } + limitCounter := httprate.NewLocalLimitCounter(time.Minute) currentWindow := time.Now().UTC().Truncate(time.Minute) previousWindow := currentWindow.Add(-time.Minute)