Skip to content

Commit

Permalink
fix(transport): make connection multiaddrs match the full multiaddr i…
Browse files Browse the repository at this point in the history
…ncluding sni and certhash components
  • Loading branch information
aschmahmann committed Oct 8, 2024
1 parent f008fe3 commit 2bf2ad6
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 38 deletions.
14 changes: 6 additions & 8 deletions p2p/test/transport/gating_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) {
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport and WebRTC addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String())
require.Equal(t, h2.Addrs()[0].String(), addrs.RemoteMultiaddr().String())
}),
)
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
Expand Down Expand Up @@ -135,8 +135,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) {
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr())
require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr())
require.Equal(t, h1.ID(), c.LocalPeer())
require.Equal(t, h2.ID(), c.RemotePeer())
}))
Expand Down Expand Up @@ -170,12 +169,12 @@ func TestInterceptAccept(t *testing.T) {
// In WebRTC, retransmissions of the STUN packet might cause us to create multiple connections,
// if the first connection attempt is rejected.
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
// remove the certhash component from WebRTC and WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
}).AnyTimes()
} else {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
// remove the certhash component from WebRTC and WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
})
}
Expand Down Expand Up @@ -213,8 +212,7 @@ func TestInterceptSecuredIncoming(t *testing.T) {
gomock.InOrder(
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
require.Equal(t, h2.Addrs()[0], addrs.LocalMultiaddr())
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
Expand Down Expand Up @@ -248,7 +246,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr())
require.Equal(t, h2.Addrs()[0], c.LocalMultiaddr())
require.Equal(t, h1.ID(), c.RemotePeer())
require.Equal(t, h2.ID(), c.LocalPeer())
}),
Expand Down
26 changes: 26 additions & 0 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,3 +867,29 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) {
})
}
}

func TestConnMatchingAddress(t *testing.T) {
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
server := tc.HostGenerator(t, TransportTestCaseOpts{})
client1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
client2 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
defer server.Close()
defer client1.Close()
defer client2.Close()

client1.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := client1.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: server.Addrs()})
require.NoError(t, err)

client1Conns := client1.Network().ConnsToPeer(server.ID())
require.Equal(t, 1, len(client1Conns))
remoteMA := client1Conns[0].RemoteMultiaddr()

err = client2.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: []ma.Multiaddr{remoteMA}})
require.NoError(t, err)
})
}
}
3 changes: 1 addition & 2 deletions p2p/transport/webrtc/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,13 @@ func (l *listener) setupConnection(
return nil, err
}

localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })
conn, err := newConnection(
network.DirInbound,
w.PeerConnection,
l.transport,
scope,
l.transport.localPeerId,
localMultiaddrWithoutCerthash,
l.localMultiaddr,
remotePeer,
remotePubKey,
remoteMultiaddr,
Expand Down
3 changes: 1 addition & 2 deletions p2p/transport/webrtc/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,6 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
if err != nil {
return nil, err
}
remoteMultiaddrWithoutCerthash, _ := ma.SplitFunc(remoteMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })

conn, err := newConnection(
network.DirOutbound,
Expand All @@ -398,7 +397,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
localAddr,
p,
remotePubKey,
remoteMultiaddrWithoutCerthash,
remoteMultiaddr,
w.IncomingDataChannels,
w.PeerConnectionClosedCh,
)
Expand Down
76 changes: 71 additions & 5 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package websocket

import (
"fmt"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"io"
"net"
"sync"
Expand All @@ -25,17 +28,72 @@ type Conn struct {
closeOnce sync.Once

readLock, writeLock sync.Mutex

laddr, raddr *Addr
laddrma, raddrma ma.Multiaddr
}

var _ net.Conn = (*Conn)(nil)

// NewConn creates a Conn given a regular gorilla/websocket Conn.
func NewConn(raw *ws.Conn, secure bool) *Conn {
// NewOutboundConn creates an outbound Conn given a regular gorilla/websocket Conn.
func NewOutboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) {
return newConn(raw, secure, sni, false)
}

// NewInboundConn creates an inbound Conn given a regular gorilla/websocket Conn.
func NewInboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) {
return newConn(raw, secure, sni, true)
}

// newConn creates a Conn given a regular gorilla/websocket Conn.
func newConn(raw *ws.Conn, secure bool, sni string, inbound bool) (*Conn, error) {
laddr := NewAddrWithScheme(raw.LocalAddr().String(), secure)
raddr := NewAddrWithScheme(raw.RemoteAddr().String(), secure)

laddrma, err := manet.FromNetAddr(laddr)
if err != nil {
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
}

raddrma, err := manet.FromNetAddr(raddr)
if err != nil {
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
}

if secure && sni != "" {
var wssMA ma.Multiaddr
if inbound {
wssMA = laddrma
} else {
wssMA = raddrma
}

if withoutWSS := wssMA.Decapsulate(ma.StringCast("/wss")); withoutWSS.Equal(wssMA) {
return nil, fmt.Errorf("missing wss component from converted multiaddr")
} else {
tlsSniWsMa, err := ma.NewMultiaddr(fmt.Sprintf("/tls/sni/%s/ws", sni))
if err != nil {
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
}
wssMA = withoutWSS.Encapsulate(tlsSniWsMa)
}

if inbound {
laddrma = wssMA
} else {
raddrma = wssMA
}
}

return &Conn{
Conn: raw,
secure: secure,
DefaultMessageType: ws.BinaryMessage,
}
laddr: laddr,
raddr: raddr,
laddrma: laddrma,
raddrma: raddrma,
}, nil
}

func (c *Conn) Read(b []byte) (int, error) {
Expand Down Expand Up @@ -122,11 +180,19 @@ func (c *Conn) Close() error {
}

func (c *Conn) LocalAddr() net.Addr {
return NewAddrWithScheme(c.Conn.LocalAddr().String(), c.secure)
return c.laddr
}

func (c *Conn) RemoteAddr() net.Addr {
return NewAddrWithScheme(c.Conn.RemoteAddr().String(), c.secure)
return c.raddr
}

func (c *Conn) LocalMultiaddr() ma.Multiaddr {
return c.laddrma
}

func (c *Conn) RemoteMultiaddr() ma.Multiaddr {
return c.raddrma
}

func (c *Conn) SetDeadline(t time.Time) error {
Expand Down
22 changes: 13 additions & 9 deletions p2p/transport/websocket/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,20 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

var sni string
if r.TLS != nil {
sni = r.TLS.ServerName
}
mnc, err := NewInboundConn(c, l.isWss, sni)
if err != nil {
_ = c.Close()
return
}

select {
case l.incoming <- NewConn(c, l.isWss):
case l.incoming <- mnc:
case <-l.closed:
c.Close()
mnc.Close()
}
// The connection has been hijacked, it's safe to return.
}
Expand All @@ -126,13 +136,7 @@ func (l *listener) Accept() (manet.Conn, error) {
if !ok {
return nil, transport.ErrListenerClosed
}

mnc, err := manet.WrapNetConn(c)
if err != nil {
c.Close()
return nil, err
}
return mnc, nil
return c, nil
case <-l.closed:
return nil, transport.ErrListenerClosed
}
Expand Down
4 changes: 2 additions & 2 deletions p2p/transport/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
}
isWss := wsurl.Scheme == "wss"
dialer := ws.Dialer{HandshakeTimeout: 30 * time.Second}
var sni string
if isWss {
sni := ""
sni, err = raddr.ValueForProtocol(ma.P_SNI)
if err != nil {
sni = ""
Expand Down Expand Up @@ -220,7 +220,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
return nil, err
}

mnc, err := manet.WrapNetConn(NewConn(wscon, isWss))
mnc, err := NewOutboundConn(wscon, isWss, sni)
if err != nil {
wscon.Close()
return nil, err
Expand Down
5 changes: 1 addition & 4 deletions p2p/transport/webtransport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
}

func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*connSecurityMultiaddrs, error) {
local, err := toWebtransportMultiaddr(sess.LocalAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting local addr: %w", err)
}
local := l.Multiaddr()
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting remote addr: %w", err)
Expand Down
8 changes: 2 additions & 6 deletions p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
if err != nil {
return nil, err
}
sconn, err := t.upgrade(ctx, sess, p, certHashes)
sconn, err := t.upgrade(ctx, sess, p, certHashes, raddr)
if err != nil {
sess.CloseWithError(1, "")
return nil, err
Expand Down Expand Up @@ -230,15 +230,11 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string
return sess, conn, err
}

func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) {
func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash, remote ma.Multiaddr) (*connSecurityMultiaddrs, error) {
local, err := toWebtransportMultiaddr(sess.LocalAddr())
if err != nil {
return nil, fmt.Errorf("error determining local addr: %w", err)
}
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("error determining remote addr: %w", err)
}

str, err := sess.OpenStreamSync(ctx)
if err != nil {
Expand Down

0 comments on commit 2bf2ad6

Please sign in to comment.