Skip to content

Commit

Permalink
add a tls test
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 4, 2024
1 parent a2a4a6c commit 8350885
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 32 deletions.
2 changes: 1 addition & 1 deletion p2p/transport/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) {
if t.sharedTcp == nil {
list, err = t.unsharedMAListen(laddr)
} else {
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.MultistreamSelect)
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect)
}
if err != nil {
return nil, err
Expand Down
24 changes: 12 additions & 12 deletions p2p/transport/tcpreuse/demultiplex.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ var _ peekAble = (*bufio.Reader)(nil)
type DemultiplexedConnType int

const (
Unknown DemultiplexedConnType = iota
MultistreamSelect
HTTP
TLS
DemultiplexedConnType_Unknown DemultiplexedConnType = iota
DemultiplexedConnType_MultistreamSelect
DemultiplexedConnType_HTTP
DemultiplexedConnType_TLS
)

func (t DemultiplexedConnType) String() string {
switch t {
case MultistreamSelect:
case DemultiplexedConnType_MultistreamSelect:
return "MultistreamSelect"
case HTTP:
case DemultiplexedConnType_HTTP:
return "HTTP"
case TLS:
case DemultiplexedConnType_TLS:
return "TLS"
default:
return fmt.Sprintf("Unknown(%d)", int(t))
Expand Down Expand Up @@ -67,15 +67,15 @@ func ConnTypeFromConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) {
}

if IsMultistreamSelect(s) {
return MultistreamSelect, sc, nil
return DemultiplexedConnType_MultistreamSelect, sc, nil
}
if IsTLS(s) {
return TLS, sc, nil
return DemultiplexedConnType_TLS, sc, nil
}
if IsHTTP(s) {
return HTTP, sc, nil
return DemultiplexedConnType_HTTP, sc, nil
}
return Unknown, sc, nil
return DemultiplexedConnType_Unknown, sc, nil
}

// ReadSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged.
Expand All @@ -92,7 +92,7 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) {

return Sample(b), mac, nil
case errors.Is(err, bufio.ErrBufferFull):
// We can only peek < len(Sample{}) data.
// We can only peek < len(Sample{}) data.
// fallback to sampledConn
default:
return Sample{}, nil, err
Expand Down
25 changes: 13 additions & 12 deletions p2p/transport/tcpreuse/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,21 @@ func NewConnMgr(disableReuseport bool) *ConnMgr {
}
}

func (t *ConnMgr) maListen(laddr ma.Multiaddr) (manet.Listener, error) {
func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) {
if t.useReuseport() {
return t.reuse.Listen(laddr)
return t.reuse.Listen(listenAddr)
} else {
return manet.Listen(laddr)
return manet.Listen(listenAddr)
}
}

func (t *ConnMgr) useReuseport() bool {
return !t.disableReuseport && ReuseportIsAvailable()
}

func getTCPAddr(laddr ma.Multiaddr) (ma.Multiaddr, error) {
func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) {
haveTCP := false
addr, _ := ma.SplitFunc(laddr, func(c ma.Component) bool {
addr, _ := ma.SplitFunc(listenAddr, func(c ma.Component) bool {
if haveTCP {
return true
}
Expand All @@ -57,23 +57,23 @@ func getTCPAddr(laddr ma.Multiaddr) (ma.Multiaddr, error) {
return false
})
if !haveTCP {
return nil, fmt.Errorf("invalid listen addr %s, need tcp address", laddr)
return nil, fmt.Errorf("invalid listen addr %s, need tcp address", listenAddr)
}
return addr, nil
}

func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) {
func (t *ConnMgr) DemultiplexedListen(listenAddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) {
if !connType.IsKnown() {
return nil, fmt.Errorf("unknown connection type: %s", connType)
}
laddr, err := getTCPAddr(laddr)
listenAddr, err := getTCPAddr(listenAddr)
if err != nil {
return nil, err
}

t.mx.Lock()
defer t.mx.Unlock()
ml, ok := t.listeners[laddr.String()]
ml, ok := t.listeners[listenAddr.String()]
if ok {
dl, err := ml.DemultiplexedListen(connType)
if err != nil {
Expand All @@ -82,7 +82,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
return dl, nil
}

l, err := t.maListen(laddr)
l, err := t.maListen(listenAddr)
if err != nil {
return nil, err
}
Expand All @@ -92,7 +92,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
cancel()
t.mx.Lock()
defer t.mx.Unlock()
delete(t.listeners, laddr.String())
delete(t.listeners, listenAddr.String())
return l.Close()
}
ml = &multiplexedListener{
Expand All @@ -111,7 +111,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
ml.wg.Add(1)
go ml.run()

t.listeners[laddr.String()] = ml
t.listeners[listenAddr.String()] = ml

return dl, nil
}
Expand Down Expand Up @@ -172,6 +172,7 @@ func (m *multiplexedListener) run() error {

m.wg.Add(1)
go func() {
defer func() { <-acceptQueue }()
defer m.wg.Done()
// TODO: if/how do we want to handle stalled connections and stop them from clogging up the pipeline?
// Drop connection because the buffer is full
Expand Down
107 changes: 102 additions & 5 deletions p2p/transport/tcpreuse/listener_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,106 @@
package tcpreuse

import "testing"
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"testing"
"time"

func TestListenerClose(t *testing.T) {
// cm := NewConnMgr(false)
//
// cm.DemultiplexedListen("/")
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/stretchr/testify/require"
)

func selfSignedTLSConfig(t *testing.T) *tls.Config {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour)

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)

certTemplate := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Test"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}

derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv)
require.NoError(t, err)

cert := tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: priv,
}

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
return tlsConfig
}

func getTLSConn(t *testing.T, c manet.Conn) (manet.Conn, error) {
t.Helper()
return manet.WrapNetConn(tls.Server(c, selfSignedTLSConfig(t)))
}

func TestListenerSingle(t *testing.T) {
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
for disableReuseport := range []bool{true, false} {
t.Run(fmt.Sprintf("TLS-reuseport:%v", disableReuseport), func(t *testing.T) {
cm := NewConnMgr(false)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
require.NoError(t, err)

go func() {
d := tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String())
if err != nil {
t.Error("failed to dial", err, i)
return
}
buf := make([]byte, 10)
_, err = conn.Write([]byte("hello"))
if err != nil {
t.Error(err)
}
_, err = conn.Read(buf)
if err == nil {
t.Error("expected EOF got nil")
}
}
}()
for i := 0; i < 100; i++ {
c, err := l.Accept()
require.NoError(t, err)
c, err = getTLSConn(t, c)
require.NoError(t, err)
buf := make([]byte, 10)
n, err := c.Read(buf)
require.NoError(t, err)
require.Equal(t, "hello", string(buf[:n]))
c.Close()
}
})
}
}
4 changes: 2 additions & 2 deletions p2p/transport/websocket/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg
} else {
var connType tcpreuse.DemultiplexedConnType
if parsed.isWSS {
connType = tcpreuse.TLS
connType = tcpreuse.DemultiplexedConnType_TLS
} else {
connType = tcpreuse.HTTP
connType = tcpreuse.DemultiplexedConnType_HTTP
}
mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType)
if err != nil {
Expand Down

0 comments on commit 8350885

Please sign in to comment.