-
Notifications
You must be signed in to change notification settings - Fork 18
/
httprate.go
168 lines (143 loc) · 3.82 KB
/
httprate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
package httprate
import (
"net"
"net/http"
"strings"
"time"
)
// Limit builds a RateLimiter with NewRateLimiter(requestLimit, windowLength,
// options...) and returns that limiter's Handler as an http middleware.
//
// requestLimit and windowLength are passed straight through; their exact
// semantics (presumably "requestLimit requests per windowLength") are defined
// by NewRateLimiter.
func Limit(requestLimit int, windowLength time.Duration, options ...Option) func(next http.Handler) http.Handler {
	limiter := NewRateLimiter(requestLimit, windowLength, options...)
	return limiter.Handler
}
// KeyFunc derives the rate-limit key for a request. Requests that produce the
// same key share a counter. A non-nil error means no key could be derived.
type KeyFunc func(r *http.Request) (string, error)

// Option configures a RateLimiter. Options are passed to Limit, which forwards
// them to NewRateLimiter.
type Option func(rl *RateLimiter)

// ResponseHeaders sets custom names for the rate-limit response headers.
// If a field is empty, the corresponding header is omitted.
type ResponseHeaders struct {
	Limit      string // Default: X-RateLimit-Limit
	Remaining  string // Default: X-RateLimit-Remaining
	Increment  string // Default: X-RateLimit-Increment
	Reset      string // Default: X-RateLimit-Reset
	RetryAfter string // Default: Retry-After
}
// LimitAll returns the middleware produced by Limit with no Options at all.
// Which key is used then depends on NewRateLimiter's default key function
// (not shown here); presumably every request shares a single counter.
func LimitAll(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler {
	middleware := Limit(requestLimit, windowLength)
	return middleware
}
// LimitByIP returns Limit middleware whose requests are keyed by KeyByIP,
// i.e. by the (canonicalized) TCP peer address in r.RemoteAddr.
func LimitByIP(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler {
	byIP := WithKeyByIP()
	return Limit(requestLimit, windowLength, byIP)
}
// LimitByRealIP returns Limit middleware whose requests are keyed by
// KeyByRealIP, i.e. by the client IP reported in proxy headers
// (True-Client-IP, X-Real-IP, X-Forwarded-For), falling back to RemoteAddr.
func LimitByRealIP(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler {
	byRealIP := WithKeyByRealIP()
	return Limit(requestLimit, windowLength, byRealIP)
}
// Key returns a key function that ignores the request and always yields the
// fixed string key with a nil error. It is useful for grouping unrelated
// requests under one shared counter.
func Key(key string) func(r *http.Request) (string, error) {
	constant := func(*http.Request) (string, error) {
		return key, nil
	}
	return constant
}
// KeyByIP keys a request by its TCP peer address. It takes the host part of
// r.RemoteAddr, or the whole of RemoteAddr when it cannot be split into
// host and port, and passes it through canonicalizeIP. The error is always nil.
func KeyByIP(r *http.Request) (string, error) {
	addr := r.RemoteAddr
	if host, _, err := net.SplitHostPort(addr); err == nil {
		addr = host
	}
	return canonicalizeIP(addr), nil
}
// KeyByRealIP keys requests by the client IP reported by a reverse proxy.
//
// The IP is taken from the first non-blank source, in this order:
//  1. the True-Client-IP header,
//  2. the X-Real-IP header,
//  3. the left-most entry of the X-Forwarded-For header,
//  4. the host part of r.RemoteAddr (or all of RemoteAddr if it has no port).
//
// Header values are trimmed of surrounding whitespace. X-Forwarded-For
// entries are split on "," with optional whitespace, so both "a, b" and "a,b"
// yield "a". The chosen value is passed through canonicalizeIP. The error is
// always nil.
//
// Security: these headers come from the client unless a trusted reverse proxy
// overwrites them. Only use KeyByRealIP behind such a proxy; otherwise a
// client can choose its own key and sidestep the limit.
func KeyByRealIP(r *http.Request) (string, error) {
	var ip string
	if tcip := strings.TrimSpace(r.Header.Get("True-Client-IP")); tcip != "" {
		ip = tcip
	} else if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" {
		ip = xrip
	} else if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// Only the left-most (originating client) entry is used. Split on the
		// bare comma: the space after it is optional in practice.
		if i := strings.IndexByte(xff, ','); i >= 0 {
			xff = xff[:i]
		}
		ip = strings.TrimSpace(xff)
	}
	if ip == "" {
		// No usable proxy header (or a blank left-most X-Forwarded-For entry):
		// fall back to the TCP peer address.
		host, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			host = r.RemoteAddr
		}
		ip = host
	}
	return canonicalizeIP(ip), nil
}
// KeyByEndpoint keys a request by its URL path (r.URL.Path, which excludes
// the query string), so each path gets its own counter. The error is always
// nil. It dereferences r.URL, so r.URL must be non-nil.
func KeyByEndpoint(r *http.Request) (string, error) {
	path := r.URL.Path
	return path, nil
}
// WithKeyFuncs returns an Option that sets the limiter's key function to the
// composition of keyFuncs (see composedKeyFunc). When no key functions are
// given, the Option leaves the limiter's key function untouched.
func WithKeyFuncs(keyFuncs ...KeyFunc) Option {
	return func(limiter *RateLimiter) {
		if len(keyFuncs) == 0 {
			return
		}
		limiter.keyFn = composedKeyFunc(keyFuncs...)
	}
}
// WithKeyByIP returns an Option that keys requests with KeyByIP alone.
func WithKeyByIP() Option {
	opt := WithKeyFuncs(KeyByIP)
	return opt
}
// WithKeyByRealIP returns an Option that keys requests with KeyByRealIP alone.
func WithKeyByRealIP() Option {
	opt := WithKeyFuncs(KeyByRealIP)
	return opt
}
// WithLimitHandler returns an Option that stores h as the limiter's
// onRateLimited handler. Presumably it is invoked instead of the next handler
// when a request exceeds the limit; that logic lives outside this view.
func WithLimitHandler(h http.HandlerFunc) Option {
	setLimitHandler := func(limiter *RateLimiter) {
		limiter.onRateLimited = h
	}
	return setLimitHandler
}
// WithErrorHandler returns an Option that stores h as the limiter's onError
// callback. It presumably receives errors such as a failing KeyFunc; confirm
// against the limiter's Handler implementation.
func WithErrorHandler(h func(http.ResponseWriter, *http.Request, error)) Option {
	setErrorHandler := func(limiter *RateLimiter) {
		limiter.onError = h
	}
	return setErrorHandler
}
// WithLimitCounter returns an Option that replaces the limiter's LimitCounter
// with c. Presumably this is the backend that stores per-key request counts;
// the LimitCounter interface is defined elsewhere.
func WithLimitCounter(c LimitCounter) Option {
	setCounter := func(limiter *RateLimiter) {
		limiter.limitCounter = c
	}
	return setCounter
}
// WithResponseHeaders returns an Option that replaces the limiter's response
// header names with headers. Empty fields omit the corresponding header.
func WithResponseHeaders(headers ResponseHeaders) Option {
	setHeaders := func(limiter *RateLimiter) {
		limiter.headers = headers
	}
	return setHeaders
}
// WithNoop returns an Option that does nothing. It is handy as a placeholder
// when building option lists conditionally.
func WithNoop() Option {
	noop := func(*RateLimiter) {}
	return noop
}
// composedKeyFunc chains keyFuncs into one KeyFunc. The resulting key is each
// part followed by ':' (so a single part "a" yields "a:"). The first error
// from any part aborts composition and is returned with an empty key.
func composedKeyFunc(keyFuncs ...KeyFunc) KeyFunc {
	return func(r *http.Request) (string, error) {
		var sb strings.Builder
		for _, fn := range keyFuncs {
			part, err := fn(r)
			if err != nil {
				return "", err
			}
			sb.WriteString(part)
			sb.WriteByte(':')
		}
		return sb.String(), nil
	}
}
// canonicalizeIP returns a form of ip suitable for comparison to other IPs,
// so that addresses belonging to the same client map to the same key.
//
//   - IPv4 addresses (dotted form) are returned unchanged.
//   - IPv4-mapped IPv6 addresses such as "::ffff:192.0.2.1" are returned in
//     their dotted IPv4 form ("192.0.2.1"). Masking them to /64 like other
//     IPv6 addresses would zero the embedded IPv4 address and collapse every
//     such client into the single key "::".
//   - Other IPv6 addresses are reduced to their /64 prefix, e.g.
//     "2001:db8:1:2:3:4:5:6" becomes "2001:db8:1:2::". A /64 is the usual
//     size of one end-user subnet, so it is treated as one client.
//   - Anything that does not parse as an IP address is returned unchanged.
func canonicalizeIP(ip string) string {
	// This mirrors how net.ParseIP decides whether an address is IPv6:
	// whichever of '.' or ':' appears first wins.
	// https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/ip.go;l=704
	i := strings.IndexAny(ip, ".:")
	if i < 0 || ip[i] == '.' {
		// No separator at all (not an IP), or IPv4: use as-is.
		return ip
	}

	parsed := net.ParseIP(ip)
	if parsed == nil {
		return ip
	}
	if v4 := parsed.To4(); v4 != nil {
		// IPv4-mapped IPv6 address: key it like the plain IPv4 client.
		return v4.String()
	}
	return parsed.Mask(net.CIDRMask(64, 128)).String()
}