Skip to content

Commit

Permalink
We received a vulnerability disclosure due to how we pick a remote IP…
Browse files Browse the repository at this point in the history
… address. (#99)

* We received a vulnerability disclosure due to how we pick a remote IP address.

Disclosure URL: https://gist.github.com/adam-p/4b777de4bda0027f4c3daa45618adcdc

This is an attempt to address the situation.

1. We no longer configure SetIPLookups on default.

2. We address the two different SetIPLookups confusion in two different place by removing both of them.

3. We add a new, explicit way, for user to define how IP address should be picked up.

Tests are all updated to use the new method of picking IP address.

This will be a backward incompatible change so version number has to be bumped to 7.

* Make golint happy.

* Update documentation.

* We don’t need the ability to pick which header to use.

* Fix tests.

---------

Co-authored-by: Didip Kerabat <didipkerabat@didip-personal-mbp.local>
  • Loading branch information
didip and Didip Kerabat authored Oct 9, 2024
1 parent 95418ad commit f934686
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 156 deletions.
45 changes: 38 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ This is a generic middleware to rate-limit HTTP requests.

**v7.x.x:** Replaced `time/rate` with `embedded time/rate` so that we can support more rate limit headers.

**v8.x.x:** Address `RemoteIP` vulnerability concern by replacing it with `RemoteIPFromIPLookup`, an explicit way to pick the IP address.


## Five Minute Tutorial

```go
Expand All @@ -34,6 +37,7 @@ import (
"net/http"

"github.com/didip/tollbooth/v7"
"github.com/didip/tollbooth/v7/limiter"
)

func HelloHandler(w http.ResponseWriter, req *http.Request) {
Expand All @@ -42,7 +46,15 @@ func HelloHandler(w http.ResponseWriter, req *http.Request) {

func main() {
// Create a request limiter per handler.
http.Handle("/", tollbooth.LimitFuncHandler(tollbooth.NewLimiter(1, nil), HelloHandler))
lmt := tollbooth.NewLimiter(1, nil)

// New in version >= 8, you must explicitly define how to pick the IP address.
lmt.SetIPLookup(limiter.IPLookup{
Name: "X-Real-IP",
IndexFromRight: 0,
})

http.Handle("/", tollbooth.LimitFuncHandler(lmt, HelloHandler))
http.ListenAndServe(":12345", nil)
}
```
Expand All @@ -66,10 +78,24 @@ func main() {
// every token bucket in it will expire 1 hour after it was initially set.
lmt = tollbooth.NewLimiter(1, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Hour})

// Configure list of places to look for IP address.
// By default it's: "RemoteAddr", "X-Forwarded-For", "X-Real-IP"
// If your application is behind a proxy, set "X-Forwarded-For" first.
lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"})
// New in version >= 8, you must explicitly define how to pick the IP address.
// If IP address cannot be found, rate limiter will not be activated.
lmt.SetIPLookup(limiter.IPLookup{
// The name of lookup method.
// Possible options are: RemoteAddr, X-Forwarded-For, X-Real-IP, CF-Connecting-IP
// All other headers are considered unknown and will be ignored.
Name: "X-Real-IP",

// The index position to pick the ip address from a comma separated list.
// The index goes from right to left.
//
// When there are multiple of the same headers,
// we will concat them together in the order of first to last seen.
// And then we pick the IP using this index position.
IndexFromRight: 0,
})

// In version >= 8, lmt.SetIPLookups and lmt.GetIPLookups are removed.

// Limit only GET and POST requests.
lmt.SetMethods([]string{"GET", "POST"})
Expand All @@ -89,8 +115,7 @@ func main() {
lmt.RemoveHeaderEntries("X-Access-Token", []string{"limitless-token"})

// By the way, the setters are chainable. Example:
lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}).
SetMethods([]string{"GET", "POST"}).
lmt.SetMethods([]string{"GET", "POST"}).
SetBasicAuthUsers([]string{"sansa"}).
SetBasicAuthUsers([]string{"tyrion"})
```
Expand Down Expand Up @@ -137,6 +162,12 @@ func main() {
```go
lmt := tollbooth.NewLimiter(1, nil)
// New in version >= 8, you must explicitly define how to pick the IP address.
lmt.SetIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 0,
})
// Set a custom message.
lmt.SetMessage("You have reached maximum request limit.")
Expand Down
53 changes: 26 additions & 27 deletions libstring/libstring.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"net"
"net/http"
"strings"

"github.com/didip/tollbooth/v7/limiter"
)

// StringInSlice finds needle in a slice of strings.
Expand All @@ -17,38 +19,35 @@ func StringInSlice(sliceString []string, needle string) bool {
return false
}

// RemoteIP finds IP Address given http.Request struct.
func RemoteIP(ipLookups []string, forwardedForIndexFromBehind int, r *http.Request) string {
realIP := r.Header.Get("X-Real-IP")
forwardedFor := r.Header.Get("X-Forwarded-For")

for _, lookup := range ipLookups {
if lookup == "RemoteAddr" {
// 1. Cover the basic use cases for both ipv4 and ipv6
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
// 2. Upon error, just return the remote addr.
return r.RemoteAddr
}
return ip
// RemoteIPFromIPLookup picks an ip address explicitly from limiter.IPLookup criteria.
// This function is intended to replace RemoteIP function.
func RemoteIPFromIPLookup(ipLookup limiter.IPLookup, r *http.Request) string {
switch ipLookup.Name {
case "RemoteAddr":
// 1. Cover the basic use cases for both ipv4 and ipv6
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
// 2. Upon error, just return the remote addr.
return r.RemoteAddr
}
if lookup == "X-Forwarded-For" && forwardedFor != "" {
// X-Forwarded-For is potentially a list of addresses separated with ","
parts := strings.Split(forwardedFor, ",")
for i, p := range parts {
parts[i] = strings.TrimSpace(p)
}
return ip

partIndex := len(parts) - 1 - forwardedForIndexFromBehind
if partIndex < 0 {
partIndex = 0
}
case "X-Forwarded-For", "X-Real-IP", "CF-Connecting-IP":
ipAddrListCommaSeparated := r.Header.Values(ipLookup.Name)

return parts[partIndex]
ipAddrCommaSeparated := strings.Join(ipAddrListCommaSeparated, ",")

ips := strings.Split(ipAddrCommaSeparated, ",")
for i, p := range ips {
ips[i] = strings.TrimSpace(p)
}
if lookup == "X-Real-IP" && realIP != "" {
return realIP

ipIndex := len(ips) - 1 - ipLookup.IndexFromRight
if ipIndex < 0 {
ipIndex = 0
}

return ips[ipIndex]
}

return ""
Expand Down
103 changes: 52 additions & 51 deletions libstring/libstring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"net/http"
"strings"
"testing"

"github.com/didip/tollbooth/v7/limiter"
)

func TestStringInSlice(t *testing.T) {
Expand All @@ -12,28 +14,7 @@ func TestStringInSlice(t *testing.T) {
}
}

func TestRemoteIPDefault(t *testing.T) {
ipLookups := []string{"RemoteAddr", "X-Real-IP"}
ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8"

request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!"))
if err != nil {
t.Errorf("Unable to create new HTTP request. Error: %v", err)
}

request.Header.Set("X-Real-IP", ipv6)

ip := RemoteIP(ipLookups, 0, request)
if ip != request.RemoteAddr {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
if ip == ipv6 {
t.Errorf("X-Real-IP should have been skipped. IP: %v", ip)
}
}

func TestRemoteIPForwardedFor(t *testing.T) {
ipLookups := []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"}
ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8"

request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!"))
Expand All @@ -44,7 +25,11 @@ func TestRemoteIPForwardedFor(t *testing.T) {
request.Header.Set("X-Forwarded-For", "10.10.10.10")
request.Header.Set("X-Real-IP", ipv6)

ip := RemoteIP(ipLookups, 0, request)
ip := RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 0,
}, request)

if ip != "10.10.10.10" {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
Expand All @@ -54,7 +39,6 @@ func TestRemoteIPForwardedFor(t *testing.T) {
}

func TestRemoteIPRealIP(t *testing.T) {
ipLookups := []string{"X-Real-IP", "X-Forwarded-For", "RemoteAddr"}
ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8"

request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!"))
Expand All @@ -65,7 +49,11 @@ func TestRemoteIPRealIP(t *testing.T) {
request.Header.Set("X-Forwarded-For", "10.10.10.10")
request.Header.Set("X-Real-IP", ipv6)

ip := RemoteIP(ipLookups, 0, request)
ip := RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Real-IP",
IndexFromRight: 0,
}, request)

if ip != ipv6 {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
Expand All @@ -74,53 +62,64 @@ func TestRemoteIPRealIP(t *testing.T) {
}
}

func TestRemoteIPMultipleForwardedFor(t *testing.T) {
ipLookups := []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"}
ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8"

func TestRemoteIPMultipleForwardedForIPAddresses(t *testing.T) {
request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!"))
if err != nil {
t.Errorf("Unable to create new HTTP request. Error: %v", err)
}

request.Header.Set("X-Real-IP", ipv6)

// Missing X-Forwarded-For should not break things
ip := RemoteIP(ipLookups, 0, request)
if ip != ipv6 {
t.Errorf("X-Real-IP should have been chosen because X-Forwarded-For is missing. IP: %v", ip)
}

request.Header.Set("X-Forwarded-For", "10.10.10.10,10.10.10.11")

ip := RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 0,
}, request)

// Should get the last one
ip = RemoteIP(ipLookups, 0, request)
if ip != "10.10.10.11" {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
if ip == ipv6 {
t.Errorf("X-Real-IP should have been skipped. IP: %v", ip)
}

ip = RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 1,
}, request)

// Should get the 2nd from last
ip = RemoteIP(ipLookups, 1, request)
if ip != "10.10.10.10" {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
if ip == ipv6 {
t.Errorf("X-Real-IP should have been skipped. IP: %v", ip)
}

// What about index out of bound? RemoteIP should simply choose index 0.
ip = RemoteIP(ipLookups, 2, request)
ip = RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 2,
}, request)

if ip != "10.10.10.10" {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
if ip == ipv6 {
t.Errorf("X-Real-IP should have been skipped. IP: %v", ip)
}
}

func TestRemoteIPMultipleForwardedForHeaders(t *testing.T) {
request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!"))
if err != nil {
t.Errorf("Unable to create new HTTP request. Error: %v", err)
}

request.Header.Add("X-Forwarded-For", "8.8.8.8,8.8.4.4")
request.Header.Add("X-Forwarded-For", "10.10.10.10,10.10.10.11")

ip := RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 0,
}, request)

// Should get the last header and the last IP
if ip != "10.10.10.11" {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
}
func TestCanonicalizeIP(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -169,10 +168,12 @@ func TestCanonicalizeIP(t *testing.T) {
},
}
for _, tt := range tests {
tt := tt
ip := tt.ip
want := tt.want

t.Run(tt.name, func(t *testing.T) {
if got := CanonicalizeIP(tt.ip); got != tt.want {
t.Errorf("CanonicalizeIP() = %v, want %v", got, tt.want)
if got := CanonicalizeIP(ip); got != want {
t.Errorf("CanonicalizeIP() = %v, want %v", got, want)
}
})
}
Expand Down
Loading

0 comments on commit f934686

Please sign in to comment.