-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhttp.go
130 lines (111 loc) · 4.88 KB
/
http.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package redis_rate_limiter
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
)
var (
_ http.Handler = &httpRateLimiterHandler{}
_ Extractor = &httpHeaderExtractor{}
stateStrings = map[State]string{
Allow: "Allow",
Deny: "Deny",
}
)
const (
rateLimitingTotalRequests = "Rate-Limiting-Total-Requests"
rateLimitingState = "Rate-Limiting-State"
rateLimitingExpiresAt = "Rate-Limiting-Expires-At"
)
// Extractor represents the way we will extract a key from an HTTP request, this could be
// a value from a header, request path, method used, user authentication information, any information that
// is available at the HTTP request that wouldn't cause side effects if it was collected (this object shouldn't
// read the body of the request).
type Extractor interface {
Extract(r *http.Request) (string, error)
}
type httpHeaderExtractor struct {
headers []string
}
// Extract extracts a collection of http headers and joins them to build the key that will be used for
// rate limiting. You should use headers that are guaranteed to be unique for a client.
func (h *httpHeaderExtractor) Extract(r *http.Request) (string, error) {
values := make([]string, 0, len(h.headers))
for _, key := range h.headers {
// if we can't find a value for the headers, give up and return an error.
if value := strings.TrimSpace(r.Header.Get(key)); value == "" {
return "", fmt.Errorf("the header %v must have a value set", key)
} else {
values = append(values, value)
}
}
return strings.Join(values, "-"), nil
}
// NewHTTPHeadersExtractor creates a new HTTP header extractor
func NewHTTPHeadersExtractor(headers ...string) Extractor {
return &httpHeaderExtractor{headers: headers}
}
// RateLimiterConfig holds the basic config we need to create a middleware http.Handler object that
// performs rate limiting before offloading the request to an actual handler.
type RateLimiterConfig struct {
Extractor Extractor
Strategy Strategy
Expiration time.Duration
MaxRequests uint64
}
// NewHTTPRateLimiterHandler wraps an existing http.Handler object performing rate limiting before
// sending the request to the wrapped handler. If any errors happen while trying to rate limit a request
// or if the request is denied, the rate limiting handler will send a response to the client and will not
// call the wrapped handler.
func NewHTTPRateLimiterHandler(originalHandler http.Handler, config *RateLimiterConfig) http.Handler {
return &httpRateLimiterHandler{
handler: originalHandler,
config: config,
}
}
type httpRateLimiterHandler struct {
handler http.Handler
config *RateLimiterConfig
}
func (h *httpRateLimiterHandler) writeRespone(writer http.ResponseWriter, status int, msg string, args ...interface{}) {
writer.Header().Set("Content-Type", "text/plain")
writer.WriteHeader(status)
if _, err := writer.Write([]byte(fmt.Sprintf(msg, args...))); err != nil {
fmt.Printf("failed to write body to HTTP request: %v", err)
}
}
// ServeHTTP performs rate limiting with the configuration it was provided and if there were not errors
// and the request was allowed it is sent to the wrapped handler. It also adds rate limiting headers that will be
// sent to the client to make it aware of what state it is in terms of rate limiting.
func (h *httpRateLimiterHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
key, err := h.config.Extractor.Extract(request)
if err != nil {
h.writeRespone(writer, http.StatusBadRequest, "failed to collect rate limiting key from request: %v", err)
return
}
result, err := h.config.Strategy.Run(request.Context(), &Request{
Key: key,
Limit: h.config.MaxRequests,
Duration: h.config.Expiration,
})
if err != nil {
h.writeRespone(writer, http.StatusInternalServerError, "failed to run rate limiting for request: %v", err)
return
}
// set the rate limiting headers both on allow or deny results so the client knows what is going on
writer.Header().Set(rateLimitingTotalRequests, strconv.FormatUint(result.TotalRequests, 10))
writer.Header().Set(rateLimitingState, stateStrings[result.State])
writer.Header().Set(rateLimitingExpiresAt, result.ExpiresAt.Format(time.RFC3339))
// when the state is Deny, just return a 429 response to the client and stop the request handling flow
if result.State == Deny {
h.writeRespone(writer, http.StatusTooManyRequests, "you have sent too many requests to this service, slow down please")
return
}
// if the request was not denied we assume it was allowed and call the wrapped handler.
// by leaving this to the end we make sure the wrapped handler is only called once and doesn't have to worry
// about any rate limiting at all (it doesn't even have to know there was rate limiting happening for this request)
// as we have already set the headers, so when the handler flushes the response the headers above will be sent.
h.handler.ServeHTTP(writer, request)
}