Skip to content

Commit

Permalink
Add IncrementBy (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
klaidliadon authored Dec 6, 2023
1 parent 2aed83f commit 3327e65
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
16 changes: 16 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package httprate

import "context"

var incrementKey = &struct{}{}

func WithIncrement(ctx context.Context, value int) context.Context {
return context.WithValue(ctx, incrementKey, value)
}

func getIncrement(ctx context.Context) int {
if value, ok := ctx.Value(incrementKey).(int); ok {
return value
}
return 1
}
9 changes: 7 additions & 2 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
type LimitCounter interface {
Config(requestLimit int, windowLength time.Duration)
Increment(key string, currentWindow time.Time) error
IncrementBy(key string, currentWindow time.Time, amount int) error
Get(key string, currentWindow, previousWindow time.Time) (int, int, error)
}

Expand Down Expand Up @@ -119,7 +120,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
return
}

err = l.limitCounter.Increment(key, currentWindow)
err = l.limitCounter.IncrementBy(key, currentWindow, getIncrement(r.Context()))
if err != nil {
l.mu.Unlock()
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -152,6 +153,10 @@ func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {
}

func (c *localCounter) Increment(key string, currentWindow time.Time) error {
return c.IncrementBy(key, currentWindow, 1)
}

func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount int) error {
c.evict()

c.mu.Lock()
Expand All @@ -164,7 +169,7 @@ func (c *localCounter) Increment(key string, currentWindow time.Time) error {
v = &count{}
c.counters[hkey] = v
}
v.value += 1
v.value += amount
v.updatedAt = time.Now()

return nil
Expand Down
39 changes: 39 additions & 0 deletions limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,45 @@ func TestLimit(t *testing.T) {
}
}

func TestWithIncrement(t *testing.T) {
type test struct {
name string
requestsLimit int
windowLength time.Duration
respCodes []int
}
tests := []test{
{
name: "no-block",
requestsLimit: 3,
windowLength: 4 * time.Second,
respCodes: []int{200, 200, 429},
},
{
name: "block",
requestsLimit: 3,
windowLength: 2 * time.Second,
respCodes: []int{200, 200, 429, 429},
},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
router := httprate.LimitAll(tt.requestsLimit, tt.windowLength)(h)

for _, code := range tt.respCodes {
req := httptest.NewRequest("GET", "/", nil)
req = req.WithContext(httprate.WithIncrement(req.Context(), 2))
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if respCode := recorder.Result().StatusCode; respCode != code {
t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respCode, code)
}
}
})
}
}

func TestLimitHandler(t *testing.T) {
type test struct {
name string
Expand Down

0 comments on commit 3327e65

Please sign in to comment.