Skip to content

Commit

Permalink
Merge pull request #74 from unmarshal/read_header_timeout
Browse files Browse the repository at this point in the history
Add support for ReadHeaderTimeout
  • Loading branch information
pires authored Apr 22, 2021
2 parents 7f48261 + cdc6386 commit 3aa7ea9
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 3 deletions.
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,38 @@ func main() {
}
```

### HTTP Server
```go
package main

import (
"net"
"net/http"
"time"

"github.com/pires/go-proxyproto"
)

func main() {
server := http.Server{
Addr: ":8080",
}

ln, err := net.Listen("tcp", server.Addr)
if err != nil {
panic(err)
}

proxyListener := &proxyproto.Listener{
Listener: ln,
ReadHeaderTimeout: 10 * time.Second,
}
defer proxyListener.Close()

server.Serve(proxyListener)
}
```

## Special notes

### AWS
Expand Down
48 changes: 48 additions & 0 deletions examples/client/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package main

import (
"io"
"log"
"net"

proxyproto "github.com/pires/go-proxyproto"
)

func chkErr(err error) {
if err != nil {
log.Fatalf("Error: %s", err.Error())
}
}

func main() {
// Dial some proxy listener e.g. https://github.com/mailgun/proxyproto
target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:9876")
chkErr(err)

conn, err := net.DialTCP("tcp", nil, target)
chkErr(err)

defer conn.Close()

// Create a proxyprotocol header or use HeaderProxyFromAddrs() if you
// have two conn's
header := &proxyproto.Header{
Version: 1,
Command: proxyproto.PROXY,
TransportProtocol: proxyproto.TCPv4,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP("10.1.1.1"),
Port: 1000,
},
DestinationAddr: &net.TCPAddr{
IP: net.ParseIP("20.2.2.2"),
Port: 2000,
},
}
// After the connection was created write the proxy headers first
_, err = header.WriteTo(conn)
chkErr(err)
// Then your data... e.g.:
_, err = io.WriteString(conn, "HELO")
chkErr(err)
}
39 changes: 39 additions & 0 deletions examples/httpserver/httpserver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package main

import (
"log"
"net"
"net/http"
"time"

"github.com/pires/go-proxyproto"
)

// TODO: add httpclient example

func main() {
server := http.Server{
Addr: ":8080",
ConnState: func(c net.Conn, s http.ConnState) {
if s == http.StateNew {
log.Printf("[ConnState] %s -> %s", c.LocalAddr().String(), c.RemoteAddr().String())
}
},
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("[Handler] remote ip %q", r.RemoteAddr)
}),
}

ln, err := net.Listen("tcp", server.Addr)
if err != nil {
panic(err)
}

proxyListener := &proxyproto.Listener{
Listener: ln,
ReadHeaderTimeout: 10 * time.Second,
}
defer proxyListener.Close()

server.Serve(proxyListener)
}
36 changes: 36 additions & 0 deletions examples/server/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package main

import (
"log"
"net"

proxyproto "github.com/pires/go-proxyproto"
)

func main() {
// Create a listener
addr := "localhost:9876"
list, err := net.Listen("tcp", addr)
if err != nil {
log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error())
}

// Wrap listener in a proxyproto listener
proxyListener := &proxyproto.Listener{Listener: list}
defer proxyListener.Close()

// Wait for a connection and accept it
conn, err := proxyListener.Accept()
defer conn.Close()

// Print connection details
if conn.LocalAddr() == nil {
log.Fatal("couldn't retrieve local address")
}
log.Printf("local address: %q", conn.LocalAddr().String())

if conn.RemoteAddr() == nil {
log.Fatal("couldn't retrieve remote address")
}
log.Printf("remote address: %q", conn.RemoteAddr().String())
}
11 changes: 8 additions & 3 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import (
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
type Listener struct {
Listener net.Listener
Policy PolicyFunc
ValidateHeader Validator
Listener net.Listener
Policy PolicyFunc
ValidateHeader Validator
ReadHeaderTimeout time.Duration
}

// Conn is used to wrap and underlying connection which
Expand Down Expand Up @@ -52,6 +53,10 @@ func (p *Listener) Accept() (net.Conn, error) {
return nil, err
}

if d := p.ReadHeaderTimeout; d != 0 {
conn.SetReadDeadline(time.Now().Add(d))
}

proxyHeaderPolicy := USE
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
Expand Down
38 changes: 38 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ package proxyproto

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"net"
"testing"
"time"
)

func TestPassthrough(t *testing.T) {
Expand Down Expand Up @@ -61,6 +63,42 @@ func TestPassthrough(t *testing.T) {
}
}

func TestReadHeaderTimeout(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

pl := &Listener{
Listener: l,
ReadHeaderTimeout: 1 * time.Millisecond,
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

<-ctx.Done()
}()

conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

// Read blocks forever if there is no ReadHeaderTimeout
recv := make([]byte, 4)
_, err = conn.Read(recv)

}

func TestParse_ipv4(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
Expand Down

0 comments on commit 3aa7ea9

Please sign in to comment.