diff --git a/README.md b/README.md index 1aedea5..982707c 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,38 @@ func main() { } ``` +### HTTP Server +```go +package main + +import ( + "net" + "net/http" + "time" + + "github.com/pires/go-proxyproto" +) + +func main() { + server := http.Server{ + Addr: ":8080", + } + + ln, err := net.Listen("tcp", server.Addr) + if err != nil { + panic(err) + } + + proxyListener := &proxyproto.Listener{ + Listener: ln, + ReadHeaderTimeout: 10 * time.Second, + } + defer proxyListener.Close() + + server.Serve(proxyListener) +} +``` + ## Special notes ### AWS diff --git a/examples/client/client.go b/examples/client/client.go new file mode 100644 index 0000000..7c795fa --- /dev/null +++ b/examples/client/client.go @@ -0,0 +1,48 @@ +package main + +import ( + "io" + "log" + "net" + + proxyproto "github.com/pires/go-proxyproto" +) + +func chkErr(err error) { + if err != nil { + log.Fatalf("Error: %s", err.Error()) + } +} + +func main() { + // Dial some proxy listener e.g. https://github.com/mailgun/proxyproto + target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:9876") + chkErr(err) + + conn, err := net.DialTCP("tcp", nil, target) + chkErr(err) + + defer conn.Close() + + // Create a proxyprotocol header or use HeaderProxyFromAddrs() if you + // have two conn's + header := &proxyproto.Header{ + Version: 1, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP("20.2.2.2"), + Port: 2000, + }, + } + // After the connection was created write the proxy headers first + _, err = header.WriteTo(conn) + chkErr(err) + // Then your data... e.g.: + _, err = io.WriteString(conn, "HELO") + chkErr(err) +} diff --git a/examples/httpserver/httpserver.go b/examples/httpserver/httpserver.go new file mode 100644 index 0000000..b04f2c7 --- /dev/null +++ b/examples/httpserver/httpserver.go @@ -0,0 +1,39 @@ +package main + +import ( + "log" + "net" + "net/http" + "time" + + "github.com/pires/go-proxyproto" +) + +// TODO: add httpclient example + +func main() { + server := http.Server{ + Addr: ":8080", + ConnState: func(c net.Conn, s http.ConnState) { + if s == http.StateNew { + log.Printf("[ConnState] %s -> %s", c.LocalAddr().String(), c.RemoteAddr().String()) + } + }, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("[Handler] remote ip %q", r.RemoteAddr) + }), + } + + ln, err := net.Listen("tcp", server.Addr) + if err != nil { + panic(err) + } + + proxyListener := &proxyproto.Listener{ + Listener: ln, + ReadHeaderTimeout: 10 * time.Second, + } + defer proxyListener.Close() + + server.Serve(proxyListener) +} diff --git a/examples/server/server.go b/examples/server/server.go new file mode 100644 index 0000000..286dc2c --- /dev/null +++ b/examples/server/server.go @@ -0,0 +1,36 @@ +package main + +import ( + "log" + "net" + + proxyproto "github.com/pires/go-proxyproto" +) + +func main() { + // Create a listener + addr := "localhost:9876" + list, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error()) + } + + // Wrap listener in a proxyproto listener + proxyListener := &proxyproto.Listener{Listener: list} + defer proxyListener.Close() + + // Wait for a connection and accept it + conn, err := proxyListener.Accept() + defer conn.Close() + + // Print connection details + if conn.LocalAddr() == nil { + log.Fatal("couldn't retrieve local address") + } + log.Printf("local address: %q", conn.LocalAddr().String()) + + if conn.RemoteAddr() == nil { + log.Fatal("couldn't retrieve remote address") + } + log.Printf("remote address: %q", conn.RemoteAddr().String()) +} diff --git a/protocol.go b/protocol.go index d044335..0f493ba 100644 --- a/protocol.go +++ b/protocol.go @@ -13,9 +13,10 @@ import ( // If the connection is using the protocol, the RemoteAddr() will return // the correct client address. type Listener struct { - Listener net.Listener - Policy PolicyFunc - ValidateHeader Validator + Listener net.Listener + Policy PolicyFunc + ValidateHeader Validator + ReadHeaderTimeout time.Duration } // Conn is used to wrap and underlying connection which @@ -52,6 +53,10 @@ func (p *Listener) Accept() (net.Conn, error) { return nil, err } + if d := p.ReadHeaderTimeout; d != 0 { + conn.SetReadDeadline(time.Now().Add(d)) + } + proxyHeaderPolicy := USE if p.Policy != nil { proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr()) diff --git a/protocol_test.go b/protocol_test.go index e4c9d3d..3f06815 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -6,6 +6,7 @@ package proxyproto import ( "bytes" + "context" "crypto/tls" "crypto/x509" "fmt" @@ -13,6 +14,7 @@ import ( "io/ioutil" "net" "testing" + "time" ) func TestPassthrough(t *testing.T) { @@ -61,6 +63,42 @@ func TestPassthrough(t *testing.T) { } } +func TestReadHeaderTimeout(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{ + Listener: l, + ReadHeaderTimeout: 1 * time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + <-ctx.Done() + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Read blocks forever if there is no ReadHeaderTimeout + recv := make([]byte, 4) + _, err = conn.Read(recv) + +} + func TestParse_ipv4(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil {