diff --git a/context.go b/context.go new file mode 100644 index 0000000..8db9e9a --- /dev/null +++ b/context.go @@ -0,0 +1,16 @@ +package httprate + +import "context" + +var incrementKey = &struct{}{} + +func WithIncrement(ctx context.Context, value int) context.Context { + return context.WithValue(ctx, incrementKey, value) +} + +func getIncrement(ctx context.Context) int { + if value, ok := ctx.Value(incrementKey).(int); ok { + return value + } + return 1 +} diff --git a/limiter.go b/limiter.go index fcbce6d..5f55ee6 100644 --- a/limiter.go +++ b/limiter.go @@ -13,6 +13,7 @@ import ( type LimitCounter interface { Config(requestLimit int, windowLength time.Duration) Increment(key string, currentWindow time.Time) error + IncrementBy(key string, currentWindow time.Time, amount int) error Get(key string, currentWindow, previousWindow time.Time) (int, int, error) } @@ -119,7 +120,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler { return } - err = l.limitCounter.Increment(key, currentWindow) + err = l.limitCounter.IncrementBy(key, currentWindow, getIncrement(r.Context())) if err != nil { l.mu.Unlock() http.Error(w, err.Error(), http.StatusInternalServerError) @@ -152,6 +153,10 @@ func (c *localCounter) Config(requestLimit int, windowLength time.Duration) { } func (c *localCounter) Increment(key string, currentWindow time.Time) error { + return c.IncrementBy(key, currentWindow, 1) +} + +func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount int) error { c.evict() c.mu.Lock() @@ -164,7 +169,7 @@ func (c *localCounter) Increment(key string, currentWindow time.Time) error { v = &count{} c.counters[hkey] = v } - v.value += 1 + v.value += amount v.updatedAt = time.Now() return nil diff --git a/limiter_test.go b/limiter_test.go index e4724a5..b844630 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -49,6 +49,45 @@ func TestLimit(t *testing.T) { } } +func TestWithIncrement(t *testing.T) { + type test struct { + name string + requestsLimit int + windowLength time.Duration + respCodes []int + } + tests := []test{ + { + name: "no-block", + requestsLimit: 3, + windowLength: 4 * time.Second, + respCodes: []int{200, 200, 429}, + }, + { + name: "block", + requestsLimit: 3, + windowLength: 2 * time.Second, + respCodes: []int{200, 200, 429, 429}, + }, + } + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + router := httprate.LimitAll(tt.requestsLimit, tt.windowLength)(h) + + for _, code := range tt.respCodes { + req := httptest.NewRequest("GET", "/", nil) + req = req.WithContext(httprate.WithIncrement(req.Context(), 2)) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + if respCode := recorder.Result().StatusCode; respCode != code { + t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respCode, code) + } + } + }) + } +} + func TestLimitHandler(t *testing.T) { type test struct { name string