Skip to content

Commit

Permalink
Implement rate-limiting from HTTP handler (e.g. by request payload) (#42
Browse files Browse the repository at this point in the history
)
  • Loading branch information
VojtechVitek authored Aug 23, 2024
1 parent 99b3b69 commit 80029e2
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 72 deletions.
44 changes: 36 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,36 +78,64 @@ r.Use(httprate.Limit(
))
```

### Send specific response for rate limited requests
### Rate limit by request payload
```go
// Rate-limiter for login endpoint.
loginRateLimiter := httprate.NewRateLimiter(5, time.Minute)

r.Post("/login", func(w http.ResponseWriter, r *http.Request) {
var payload struct {
Username string `json:"username"`
Password string `json:"password"`
}
err := json.NewDecoder(r.Body).Decode(&payload)
if err != nil || payload.Username == "" || payload.Password == "" {
w.WriteHeader(400)
return
}

// Rate-limit login at 5 req/min.
if loginRateLimiter.OnLimit(w, r, payload.Username) {
return
}

w.Write([]byte("login at 5 req/min\n"))
})
```

### Send specific response for rate-limited requests

The default response is `HTTP 429` with `Too Many Requests` body. You can override it with:

```go
r.Use(httprate.Limit(
10,
time.Minute,
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error": "Rate limited. Please slow down."}`, http.StatusTooManyRequests)
http.Error(w, `{"error": "Rate-limited. Please, slow down."}`, http.StatusTooManyRequests)
}),
))
```

### Send specific response for backend errors
### Send specific response on errors

An error can be returned by:
- A custom key function provided by `httprate.WithKeyFunc(customKeyFn)`
- A custom backend provided by `httprateredis.WithRedisLimitCounter(customBackend)`
- The default local in-memory counter is guaranteed not return any errors
- Backends that fall-back to the local in-memory counter (e.g. [httprate-redis](https://github.com/go-chi/httprate-redis)) can choose not to return any errors either

```go
r.Use(httprate.Limit(
10,
time.Minute,
httprate.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, err error) {
// NOTE: The local in-memory counter is guaranteed not return any errors.
// Other backends may return errors, depending on whether they have
// in-memory fallback mechanism implemented in case of network errors.

http.Error(w, fmt.Sprintf(`{"error": %q}`, err), http.StatusPreconditionRequired)
}),
httprate.WithLimitCounter(customBackend),
))
```


### Send custom response headers

```go
Expand Down
56 changes: 32 additions & 24 deletions _example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"encoding/json"
"log"
"net/http"
"time"
Expand All @@ -15,52 +16,59 @@ func main() {
r := chi.NewRouter()
r.Use(middleware.Logger)

// Rate-limit all routes at 1000 req/min by IP address.
r.Use(httprate.LimitByIP(1000, time.Minute))

r.Route("/admin", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Note: this is a mock middleware to set a userID on the request context
// Note: This is a mock middleware to set a userID on the request context
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "userID", "123")))
})
})

// Here we set a specific rate limit by ip address and userID
// Rate-limit admin routes at 10 req/s by userID.
r.Use(httprate.Limit(
10,
time.Minute,
httprate.WithKeyFuncs(httprate.KeyByIP, func(r *http.Request) (string, error) {
token := r.Context().Value("userID").(string)
10, time.Second,
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
token, _ := r.Context().Value("userID").(string)
return token, nil
}),
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
// We can send custom responses for the rate limited requests, e.g. a JSON message
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error": "Too many requests"}`))
}),
))

r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("10 req/min\n"))
w.Write([]byte("admin at 10 req/s\n"))
})
})

r.Group(func(r chi.Router) {
// Here we set another rate limit (3 req/min) for a group of handlers.
//
// Note: in practice you don't need to have so many layered rate-limiters,
// but the example here is to illustrate how to control the machinery.
r.Use(httprate.LimitByIP(3, time.Minute))
// Rate-limiter for login endpoint.
loginRateLimiter := httprate.NewRateLimiter(5, time.Minute)

r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("3 req/min\n"))
})
r.Post("/login", func(w http.ResponseWriter, r *http.Request) {
var payload struct {
Username string `json:"username"`
Password string `json:"password"`
}
err := json.NewDecoder(r.Body).Decode(&payload)
if err != nil || payload.Username == "" || payload.Password == "" {
w.WriteHeader(400)
return
}

// Rate-limit login at 5 req/min.
if loginRateLimiter.OnLimit(w, r, payload.Username) {
return
}

w.Write([]byte("login at 5 req/min\n"))
})

log.Printf("Serving at localhost:3333")
log.Println()
log.Printf("Try running:")
log.Printf("curl -v http://localhost:3333")
log.Printf("curl -v http://localhost:3333/admin")
log.Printf(`curl -v http://localhost:3333?[0-1000]`)
log.Printf(`curl -v http://localhost:3333/admin?[1-12]`)
log.Printf(`curl -v http://localhost:3333/login\?[1-8] --data '{"username":"alice","password":"***"}'`)

http.ListenAndServe(":3333", r)
}
91 changes: 51 additions & 40 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,56 @@ type rateLimiter struct {
mu sync.Mutex
}

// OnLimit checks the rate limit for the given key. If the limit is reached, it returns true
// and automatically sends HTTP response. The caller should halt further request processing.
// If the limit is not reached, it increments the request count and returns false, allowing
// the request to proceed.
func (l *rateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool {
currentWindow := time.Now().UTC().Truncate(l.windowLength)
ctx := r.Context()

limit := l.requestLimit
if val := getRequestLimit(ctx); val > 0 {
limit = val
}
setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit))
setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix()))

l.mu.Lock()
_, rateFloat, err := l.calculateRate(key, limit)
if err != nil {
l.mu.Unlock()
l.onError(w, r, err)
return true
}
rate := int(math.Round(rateFloat))

increment := getIncrement(r.Context())
if increment > 1 {
setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment))
}

if rate+increment > limit {
setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate))

l.mu.Unlock()
setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585
l.onRateLimited(w, r)
return true
}

err = l.limitCounter.IncrementBy(key, currentWindow, increment)
if err != nil {
l.mu.Unlock()
l.onError(w, r, err)
return true
}
l.mu.Unlock()

setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate-increment))
return false
}

func (l *rateLimiter) Counter() LimitCounter {
return l.limitCounter
}
Expand All @@ -82,49 +132,10 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
return
}

currentWindow := time.Now().UTC().Truncate(l.windowLength)
ctx := r.Context()

limit := l.requestLimit
if val := getRequestLimit(ctx); val > 0 {
limit = val
}
setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit))
setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix()))

l.mu.Lock()
_, rateFloat, err := l.calculateRate(key, limit)
if err != nil {
l.mu.Unlock()
l.onError(w, r, err)
return
}
rate := int(math.Round(rateFloat))

increment := getIncrement(r.Context())
if increment > 1 {
setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment))
}

if rate+increment > limit {
setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate))

l.mu.Unlock()
setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585
l.onRateLimited(w, r)
if l.OnLimit(w, r, key) {
return
}

err = l.limitCounter.IncrementBy(key, currentWindow, increment)
if err != nil {
l.mu.Unlock()
l.onError(w, r, err)
return
}
l.mu.Unlock()

setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate-increment))

next.ServeHTTP(w, r)
})
}
Expand Down
57 changes: 57 additions & 0 deletions limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package httprate_test
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -437,3 +438,59 @@ func TestOverrideRequestLimit(t *testing.T) {
}
}
}

func TestRateLimitPayload(t *testing.T) {
loginRateLimiter := httprate.NewRateLimiter(5, time.Minute)

h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var payload struct {
Username string `json:"username"`
Password string `json:"password"`
}
err := json.NewDecoder(r.Body).Decode(&payload)
if err != nil || payload.Username == "" || payload.Password == "" {
w.WriteHeader(400)
return
}

// Rate-limit login at 5 req/min.
if loginRateLimiter.OnLimit(w, r, payload.Username) {
return
}

w.Write([]byte("login at 5 req/min\n"))
})

responses := []struct {
StatusCode int
Body string
}{
{StatusCode: 200, Body: "login at 5 req/min"},
{StatusCode: 200, Body: "login at 5 req/min"},
{StatusCode: 200, Body: "login at 5 req/min"},
{StatusCode: 200, Body: "login at 5 req/min"},
{StatusCode: 200, Body: "login at 5 req/min"},
{StatusCode: 429, Body: "Too Many Requests"},
{StatusCode: 429, Body: "Too Many Requests"},
{StatusCode: 429, Body: "Too Many Requests"},
}
for i, response := range responses {
req, err := http.NewRequest("GET", "/", strings.NewReader(`{"username":"alice","password":"***"}`))
if err != nil {
t.Errorf("failed = %v", err)
}

recorder := httptest.NewRecorder()
h.ServeHTTP(recorder, req)
result := recorder.Result()
if respStatus := result.StatusCode; respStatus != response.StatusCode {
t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respStatus, response.StatusCode)
}
body, _ := io.ReadAll(result.Body)
respBody := strings.TrimSuffix(string(body), "\n")

if string(respBody) != response.Body {
t.Errorf("resp.Body(%v) = %q, want %q", i, respBody, response.Body)
}
}
}

0 comments on commit 80029e2

Please sign in to comment.