Skip to content

Commit

Permalink
Introduce RespondOnLimit() vs. OnLimit() (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechVitek authored Aug 23, 2024
1 parent c4c778c commit 5e681e3
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ r.Post("/login", func(w http.ResponseWriter, r *http.Request) {
}

// Rate-limit login at 5 req/min.
if loginRateLimiter.OnLimit(w, r, payload.Username) {
if loginRateLimiter.RespondOnLimit(w, r, payload.Username) {
return
}

Expand Down
2 changes: 1 addition & 1 deletion _example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func main() {
}

// Rate-limit login at 5 req/min.
if loginRateLimiter.OnLimit(w, r, payload.Username) {
if loginRateLimiter.RespondOnLimit(w, r, payload.Username) {
return
}

Expand Down
23 changes: 17 additions & 6 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ type RateLimiter struct {
mu sync.Mutex
}

// OnLimit checks the rate limit for the given key. If the limit is reached, it returns true
// and automatically sends HTTP response. The caller should halt further request processing.
// If the limit is not reached, it increments the request count and returns false, allowing
// the request to proceed.
// OnLimit checks the rate limit for the given key and updates the response headers accordingly.
// If the limit is reached, it returns true, indicating that the request should be halted. Otherwise,
// it increments the request count and returns false. This method does not send an HTTP response,
// so the caller must handle the response themselves or use the RespondOnLimit() method instead.
func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool {
currentWindow := time.Now().UTC().Truncate(l.windowLength)
ctx := r.Context()
Expand Down Expand Up @@ -100,7 +100,6 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string

l.mu.Unlock()
setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585
l.onRateLimited(w, r)
return true
}

Expand All @@ -116,6 +115,18 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string
return false
}

// RespondOnLimit checks the rate limit for the given key and updates the response headers accordingly.
// If the limit is reached, it automatically sends an HTTP response and returns true, signaling the
// caller to halt further request processing. If the limit is not reached, it increments the request
// count and returns false, allowing the request to proceed.
func (l *RateLimiter) RespondOnLimit(w http.ResponseWriter, r *http.Request, key string) bool {
onLimit := l.OnLimit(w, r, key)
if onLimit {
l.onRateLimited(w, r)
}
return onLimit
}

func (l *RateLimiter) Counter() LimitCounter {
return l.limitCounter
}
Expand All @@ -132,7 +143,7 @@ func (l *RateLimiter) Handler(next http.Handler) http.Handler {
return
}

if l.OnLimit(w, r, key) {
if l.RespondOnLimit(w, r, key) {
return
}

Expand Down
2 changes: 1 addition & 1 deletion limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ func TestRateLimitPayload(t *testing.T) {
}

// Rate-limit login at 5 req/min.
if loginRateLimiter.OnLimit(w, r, payload.Username) {
if loginRateLimiter.RespondOnLimit(w, r, payload.Username) {
return
}

Expand Down

0 comments on commit 5e681e3

Please sign in to comment.