Skip to content

Commit

Permalink
Export RateLimiter type (#43)
Browse files Browse the repository at this point in the history
So users pass *http.RateLimiter (or save in their server struct) and use
the new .OnLimit() feature from #42.
  • Loading branch information
VojtechVitek authored Aug 23, 2024
1 parent 80029e2 commit c4c778c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
14 changes: 7 additions & 7 deletions httprate.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func Limit(requestLimit int, windowLength time.Duration, options ...Option) func
}

type KeyFunc func(r *http.Request) (string, error)
type Option func(rl *rateLimiter)
type Option func(rl *RateLimiter)

// Set custom response headers. If empty, the header is omitted.
type ResponseHeaders struct {
Expand Down Expand Up @@ -72,7 +72,7 @@ func KeyByEndpoint(r *http.Request) (string, error) {
}

func WithKeyFuncs(keyFuncs ...KeyFunc) Option {
return func(rl *rateLimiter) {
return func(rl *RateLimiter) {
if len(keyFuncs) > 0 {
rl.keyFn = composedKeyFunc(keyFuncs...)
}
Expand All @@ -88,31 +88,31 @@ func WithKeyByRealIP() Option {
}

func WithLimitHandler(h http.HandlerFunc) Option {
return func(rl *rateLimiter) {
return func(rl *RateLimiter) {
rl.onRateLimited = h
}
}

func WithErrorHandler(h func(http.ResponseWriter, *http.Request, error)) Option {
return func(rl *rateLimiter) {
return func(rl *RateLimiter) {
rl.onError = h
}
}

func WithLimitCounter(c LimitCounter) Option {
return func(rl *rateLimiter) {
return func(rl *RateLimiter) {
rl.limitCounter = c
}
}

func WithResponseHeaders(headers ResponseHeaders) Option {
return func(rl *rateLimiter) {
return func(rl *RateLimiter) {
rl.headers = headers
}
}

func WithNoop() Option {
return func(rl *rateLimiter) {}
return func(rl *RateLimiter) {}
}

func composedKeyFunc(keyFuncs ...KeyFunc) KeyFunc {
Expand Down
16 changes: 8 additions & 8 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ type LimitCounter interface {
Get(key string, currentWindow, previousWindow time.Time) (int, int, error)
}

func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter {
rl := &rateLimiter{
func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *RateLimiter {
rl := &RateLimiter{
requestLimit: requestLimit,
windowLength: windowLength,
headers: ResponseHeaders{
Expand Down Expand Up @@ -55,7 +55,7 @@ func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Opt
return rl
}

type rateLimiter struct {
type RateLimiter struct {
requestLimit int
windowLength time.Duration
keyFn KeyFunc
Expand All @@ -70,7 +70,7 @@ type rateLimiter struct {
// and automatically sends HTTP response. The caller should halt further request processing.
// If the limit is not reached, it increments the request count and returns false, allowing
// the request to proceed.
func (l *rateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool {
func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool {
currentWindow := time.Now().UTC().Truncate(l.windowLength)
ctx := r.Context()

Expand Down Expand Up @@ -116,15 +116,15 @@ func (l *rateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string
return false
}

func (l *rateLimiter) Counter() LimitCounter {
func (l *RateLimiter) Counter() LimitCounter {
return l.limitCounter
}

func (l *rateLimiter) Status(key string) (bool, float64, error) {
func (l *RateLimiter) Status(key string) (bool, float64, error) {
return l.calculateRate(key, l.requestLimit)
}

func (l *rateLimiter) Handler(next http.Handler) http.Handler {
func (l *RateLimiter) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key, err := l.keyFn(r)
if err != nil {
Expand All @@ -140,7 +140,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
})
}

func (l *rateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) {
func (l *RateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) {
now := time.Now().UTC()
currentWindow := now.Truncate(l.windowLength)
previousWindow := currentWindow.Add(-l.windowLength)
Expand Down

0 comments on commit c4c778c

Please sign in to comment.