Skip to content

Commit

Permalink
fix(transport): make connection multiaddrs match the full multiaddr i…
Browse files Browse the repository at this point in the history
…ncluding sni and certhash components
  • Loading branch information
aschmahmann committed Oct 8, 2024
1 parent 2cd6922 commit 080d716
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 37 deletions.
14 changes: 6 additions & 8 deletions p2p/test/transport/gating_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) {
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport and WebRTC addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String())
require.Equal(t, h2.Addrs()[0].String(), addrs.RemoteMultiaddr().String())
}),
)
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
Expand Down Expand Up @@ -135,8 +135,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) {
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr())
require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr())
require.Equal(t, h1.ID(), c.LocalPeer())
require.Equal(t, h2.ID(), c.RemotePeer())
}))
Expand Down Expand Up @@ -170,12 +169,12 @@ func TestInterceptAccept(t *testing.T) {
// In WebRTC, retransmissions of the STUN packet might cause us to create multiple connections,
// if the first connection attempt is rejected.
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
// remove the certhash component from WebRTC and WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
}).AnyTimes()
} else {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
// remove the certhash component from WebRTC and WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
})
}
Expand Down Expand Up @@ -213,8 +212,7 @@ func TestInterceptSecuredIncoming(t *testing.T) {
gomock.InOrder(
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
require.Equal(t, h2.Addrs()[0], addrs.LocalMultiaddr())
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
Expand Down Expand Up @@ -248,7 +246,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr())
require.Equal(t, h2.Addrs()[0], c.LocalMultiaddr())
require.Equal(t, h1.ID(), c.RemotePeer())
require.Equal(t, h2.ID(), c.LocalPeer())
}),
Expand Down
99 changes: 97 additions & 2 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@ package transport_integration
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
"net"
"runtime"
"strings"
Expand All @@ -30,8 +37,9 @@ import (
"github.com/libp2p/go-libp2p/p2p/net/swarm"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"github.com/libp2p/go-libp2p/p2p/security/noise"
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
sectls "github.com/libp2p/go-libp2p/p2p/security/tls"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
"github.com/libp2p/go-libp2p/p2p/transport/websocket"
"go.uber.org/mock/gomock"

ma "github.com/multiformats/go-multiaddr"
Expand All @@ -48,6 +56,7 @@ type TransportTestCaseOpts struct {
NoRcmgr bool
ConnGater connmgr.ConnectionGater
ResourceManager network.ResourceManager
HostSeed string
}

func transformOpts(opts TransportTestCaseOpts) []config.Option {
Expand Down Expand Up @@ -87,7 +96,7 @@ var transportsToTest = []TransportTestCase{
Name: "TCP / TLS / Yamux",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(sectls.ID, sectls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
Expand All @@ -113,6 +122,26 @@ var transportsToTest = []TransportTestCase{
return h
},
},
{
Name: "Secure WebSocket with CA Certificate",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
wsOpts := []interface{}{websocket.WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})}
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
} else {
dnsName := fmt.Sprintf("example%s.com", opts.HostSeed)
cert, err := generateSelfSignedCert(dnsName)
require.NoError(t, err)
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings(fmt.Sprintf("/ip4/127.0.0.1/tcp/0/tls/sni/%s/ws", dnsName)))
wsOpts = append(wsOpts, websocket.WithTLSConfig(&tls.Config{Certificates: []tls.Certificate{cert}}))
}
libp2pOpts = append(libp2pOpts, libp2p.Transport(websocket.New, wsOpts...))
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "QUIC",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
Expand Down Expand Up @@ -158,6 +187,46 @@ var transportsToTest = []TransportTestCase{
},
}

func generateSelfSignedCert(dnsName string) (tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, err
}

serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return tls.Certificate{}, err
}

template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"My Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour), // Valid for 1 year
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{dnsName},
}

certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return tls.Certificate{}, err
}

certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
privDER, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return tls.Certificate{}, err
}
privPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER})

// Load the certificate and key into tls.Certificate
return tls.X509KeyPair(certPEM, privPEM)
}

func TestPing(t *testing.T) {
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
Expand Down Expand Up @@ -798,3 +867,29 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) {
})
}
}

func TestConnMatchingAddress(t *testing.T) {
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
server := tc.HostGenerator(t, TransportTestCaseOpts{})
client1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
client2 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
defer server.Close()
defer client1.Close()
defer client2.Close()

client1.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := client1.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: server.Addrs()})
require.NoError(t, err)

client1Conns := client1.Network().ConnsToPeer(server.ID())
require.Equal(t, 1, len(client1Conns))
remoteMA := client1Conns[0].RemoteMultiaddr()

err = client2.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: []ma.Multiaddr{remoteMA}})
require.NoError(t, err)
})
}
}
3 changes: 1 addition & 2 deletions p2p/transport/webrtc/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,13 @@ func (l *listener) setupConnection(
return nil, err
}

localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })
conn, err := newConnection(
network.DirInbound,
w.PeerConnection,
l.transport,
scope,
l.transport.localPeerId,
localMultiaddrWithoutCerthash,
l.localMultiaddr,
remotePeer,
remotePubKey,
remoteMultiaddr,
Expand Down
3 changes: 1 addition & 2 deletions p2p/transport/webrtc/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,6 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
if err != nil {
return nil, err
}
remoteMultiaddrWithoutCerthash, _ := ma.SplitFunc(remoteMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })

conn, err := newConnection(
network.DirOutbound,
Expand All @@ -398,7 +397,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
localAddr,
p,
remotePubKey,
remoteMultiaddrWithoutCerthash,
remoteMultiaddr,
w.IncomingDataChannels,
w.PeerConnectionClosedCh,
)
Expand Down
93 changes: 91 additions & 2 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package websocket

import (
"fmt"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"io"
"net"
"sync"
Expand All @@ -25,10 +28,88 @@ type Conn struct {
closeOnce sync.Once

readLock, writeLock sync.Mutex

laddr, raddr *Addr
laddrma, raddrma ma.Multiaddr
}

var _ net.Conn = (*Conn)(nil)

// NewConn creates a Conn given a regular gorilla/websocket Conn.
func NewOutboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) {
laddr := NewAddrWithScheme(raw.LocalAddr().String(), secure)
raddr := NewAddrWithScheme(raw.RemoteAddr().String(), secure)

laddrma, err := manet.FromNetAddr(laddr)
if err != nil {
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
}

raddrma, err := manet.FromNetAddr(raddr)
if err != nil {
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
}

if secure {
if withoutWSS := raddrma.Decapsulate(ma.StringCast("/wss")); withoutWSS.Equal(raddrma) {
return nil, fmt.Errorf("missing wss component from converted multiaddr")
} else {
tlsSniWsMa, err := ma.NewMultiaddr(fmt.Sprintf("/tls/sni/%s/ws", sni))
if err != nil {
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
}
raddrma = withoutWSS.Encapsulate(tlsSniWsMa)
}
}

return &Conn{
Conn: raw,
secure: secure,
DefaultMessageType: ws.BinaryMessage,
laddr: laddr,
raddr: raddr,
laddrma: laddrma,
raddrma: raddrma,
}, nil
}

func NewInboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) {
laddr := NewAddrWithScheme(raw.LocalAddr().String(), secure)
raddr := NewAddrWithScheme(raw.RemoteAddr().String(), secure)

laddrma, err := manet.FromNetAddr(laddr)
if err != nil {
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
}

raddrma, err := manet.FromNetAddr(raddr)
if err != nil {
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
}

if secure {
if withoutWSS := laddrma.Decapsulate(ma.StringCast("/wss")); withoutWSS.Equal(laddrma) {
return nil, fmt.Errorf("missing wss component from converted multiaddr")
} else {
tlsSniWsMa, err := ma.NewMultiaddr(fmt.Sprintf("/tls/sni/%s/ws", sni))
if err != nil {
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
}
laddrma = withoutWSS.Encapsulate(tlsSniWsMa)
}
}

return &Conn{
Conn: raw,
secure: secure,
DefaultMessageType: ws.BinaryMessage,
laddr: laddr,
raddr: raddr,
laddrma: laddrma,
raddrma: raddrma,
}, nil
}

// NewConn creates a Conn given a regular gorilla/websocket Conn.
func NewConn(raw *ws.Conn, secure bool) *Conn {
return &Conn{
Expand Down Expand Up @@ -122,11 +203,19 @@ func (c *Conn) Close() error {
}

func (c *Conn) LocalAddr() net.Addr {
return NewAddrWithScheme(c.Conn.LocalAddr().String(), c.secure)
return c.laddr
}

func (c *Conn) RemoteAddr() net.Addr {
return NewAddrWithScheme(c.Conn.RemoteAddr().String(), c.secure)
return c.raddr
}

func (c *Conn) LocalMultiaddr() ma.Multiaddr {
return c.laddrma
}

func (c *Conn) RemoteMultiaddr() ma.Multiaddr {
return c.raddrma
}

func (c *Conn) SetDeadline(t time.Time) error {
Expand Down
22 changes: 13 additions & 9 deletions p2p/transport/websocket/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,20 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

var sni string
if r.TLS != nil {
sni = r.TLS.ServerName
}
mnc, err := NewInboundConn(c, l.isWss, sni)
if err != nil {
_ = c.Close()
return
}

select {
case l.incoming <- NewConn(c, l.isWss):
case l.incoming <- mnc:
case <-l.closed:
c.Close()
mnc.Close()
}
// The connection has been hijacked, it's safe to return.
}
Expand All @@ -126,13 +136,7 @@ func (l *listener) Accept() (manet.Conn, error) {
if !ok {
return nil, transport.ErrListenerClosed
}

mnc, err := manet.WrapNetConn(c)
if err != nil {
c.Close()
return nil, err
}
return mnc, nil
return c, nil
case <-l.closed:
return nil, transport.ErrListenerClosed
}
Expand Down
Loading

0 comments on commit 080d716

Please sign in to comment.