diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index f21f134248..5883e43f6a 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -269,7 +269,7 @@ func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { if t.sharedTcp == nil { list, err = t.unsharedMAListen(laddr) } else { - list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.MultistreamSelect) + list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect) } if err != nil { return nil, err diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index f0c7604b06..7b69ee35d3 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -26,19 +26,19 @@ var _ peekAble = (*bufio.Reader)(nil) type DemultiplexedConnType int const ( - Unknown DemultiplexedConnType = iota - MultistreamSelect - HTTP - TLS + DemultiplexedConnType_Unknown DemultiplexedConnType = iota + DemultiplexedConnType_MultistreamSelect + DemultiplexedConnType_HTTP + DemultiplexedConnType_TLS ) func (t DemultiplexedConnType) String() string { switch t { - case MultistreamSelect: + case DemultiplexedConnType_MultistreamSelect: return "MultistreamSelect" - case HTTP: + case DemultiplexedConnType_HTTP: return "HTTP" - case TLS: + case DemultiplexedConnType_TLS: return "TLS" default: return fmt.Sprintf("Unknown(%d)", int(t)) @@ -67,15 +67,15 @@ func ConnTypeFromConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) { } if IsMultistreamSelect(s) { - return MultistreamSelect, sc, nil + return DemultiplexedConnType_MultistreamSelect, sc, nil } if IsTLS(s) { - return TLS, sc, nil + return DemultiplexedConnType_TLS, sc, nil } if IsHTTP(s) { - return HTTP, sc, nil + return DemultiplexedConnType_HTTP, sc, nil } - return Unknown, sc, nil + return DemultiplexedConnType_Unknown, sc, nil } // ReadSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged. @@ -92,7 +92,7 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { return Sample(b), mac, nil case errors.Is(err, bufio.ErrBufferFull): - // We can only peek < len(Sample{}) data. + // We can only peek < len(Sample{}) data. // fallback to sampledConn default: return Sample{}, nil, err diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 1102cbd988..d7f566b5d0 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -33,11 +33,11 @@ func NewConnMgr(disableReuseport bool) *ConnMgr { } } -func (t *ConnMgr) maListen(laddr ma.Multiaddr) (manet.Listener, error) { +func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) { if t.useReuseport() { - return t.reuse.Listen(laddr) + return t.reuse.Listen(listenAddr) } else { - return manet.Listen(laddr) + return manet.Listen(listenAddr) } } @@ -45,9 +45,9 @@ func (t *ConnMgr) useReuseport() bool { return !t.disableReuseport && ReuseportIsAvailable() } -func getTCPAddr(laddr ma.Multiaddr) (ma.Multiaddr, error) { +func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) { haveTCP := false - addr, _ := ma.SplitFunc(laddr, func(c ma.Component) bool { + addr, _ := ma.SplitFunc(listenAddr, func(c ma.Component) bool { if haveTCP { return true } @@ -57,23 +57,23 @@ func getTCPAddr(laddr ma.Multiaddr) (ma.Multiaddr, error) { return false }) if !haveTCP { - return nil, fmt.Errorf("invalid listen addr %s, need tcp address", laddr) + return nil, fmt.Errorf("invalid listen addr %s, need tcp address", listenAddr) } return addr, nil } -func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) { +func (t *ConnMgr) DemultiplexedListen(listenAddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) { if !connType.IsKnown() { return nil, fmt.Errorf("unknown connection type: %s", connType) } - laddr, err := getTCPAddr(laddr) + listenAddr, err := getTCPAddr(listenAddr) if err != nil { return nil, err } t.mx.Lock() defer t.mx.Unlock() - ml, ok := t.listeners[laddr.String()] + ml, ok := t.listeners[listenAddr.String()] if ok { dl, err := ml.DemultiplexedListen(connType) if err != nil { @@ -82,7 +82,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed return dl, nil } - l, err := t.maListen(laddr) + l, err := t.maListen(listenAddr) if err != nil { return nil, err } @@ -92,7 +92,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed cancel() t.mx.Lock() defer t.mx.Unlock() - delete(t.listeners, laddr.String()) + delete(t.listeners, listenAddr.String()) return l.Close() } ml = &multiplexedListener{ @@ -111,7 +111,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed ml.wg.Add(1) go ml.run() - t.listeners[laddr.String()] = ml + t.listeners[listenAddr.String()] = ml return dl, nil } @@ -172,6 +172,7 @@ func (m *multiplexedListener) run() error { m.wg.Add(1) go func() { + defer func() { <-acceptQueue }() defer m.wg.Done() // TODO: if/how do we want to handle stalled connections and stop them from clogging up the pipeline? // Drop connection because the buffer is full diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go index 9bf03b589d..8bfe397e0e 100644 --- a/p2p/transport/tcpreuse/listener_test.go +++ b/p2p/transport/tcpreuse/listener_test.go @@ -1,9 +1,106 @@ package tcpreuse -import "testing" +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "testing" + "time" -func TestListenerClose(t *testing.T) { - // cm := NewConnMgr(false) - // - // cm.DemultiplexedListen("/") + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "github.com/stretchr/testify/require" +) + +func selfSignedTLSConfig(t *testing.T) *tls.Config { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.NoError(t, err) + + certTemplate := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Test"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv) + require.NoError(t, err) + + cert := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: priv, + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + return tlsConfig +} + +func getTLSConn(t *testing.T, c manet.Conn) (manet.Conn, error) { + t.Helper() + return manet.WrapNetConn(tls.Server(c, selfSignedTLSConfig(t))) +} + +func TestListenerSingle(t *testing.T) { + listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") + for disableReuseport := range []bool{true, false} { + t.Run(fmt.Sprintf("TLS-reuseport:%v", disableReuseport), func(t *testing.T) { + cm := NewConnMgr(false) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) + require.NoError(t, err) + + go func() { + d := tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}} + for i := 0; i < 100; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String()) + if err != nil { + t.Error("failed to dial", err, i) + return + } + buf := make([]byte, 10) + _, err = conn.Write([]byte("hello")) + if err != nil { + t.Error(err) + } + _, err = conn.Read(buf) + if err == nil { + t.Error("expected EOF got nil") + } + } + }() + for i := 0; i < 100; i++ { + c, err := l.Accept() + require.NoError(t, err) + c, err = getTLSConn(t, c) + require.NoError(t, err) + buf := make([]byte, 10) + n, err := c.Read(buf) + require.NoError(t, err) + require.Equal(t, "hello", string(buf[:n])) + c.Close() + } + }) + } } diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 2253b6597e..0331a6561d 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -69,9 +69,9 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg } else { var connType tcpreuse.DemultiplexedConnType if parsed.isWSS { - connType = tcpreuse.TLS + connType = tcpreuse.DemultiplexedConnType_TLS } else { - connType = tcpreuse.HTTP + connType = tcpreuse.DemultiplexedConnType_HTTP } mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType) if err != nil {