From 8b7d1518dbc5cbe2c7528f90a77b7a6e3c7174f7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 16 Sep 2022 08:52:51 +0300 Subject: [PATCH] swarm: fix selection of transport for dialing (#1653) With WebTransport's /webtransport/certhash/xyz addresses, the assumption that the last component of a multiaddr identifies the transport to use for dialing doesn't hold any more. Note that WebRTC will probably also use the certhash multiaddr component to encode its certificate hashes. --- p2p/net/swarm/swarm_addr_test.go | 46 ++++++++++++++++++++++++++++++++ p2p/net/swarm/swarm_transport.go | 18 ++++++------- 2 files changed, 54 insertions(+), 10 deletions(-) diff --git a/p2p/net/swarm/swarm_addr_test.go b/p2p/net/swarm/swarm_addr_test.go index af98400a4e..56a2740b35 100644 --- a/p2p/net/swarm/swarm_addr_test.go +++ b/p2p/net/swarm/swarm_addr_test.go @@ -2,13 +2,25 @@ package swarm_test import ( "context" + "fmt" "testing" + "github.com/libp2p/go-libp2p/core/peer" + + ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/test" + "github.com/libp2p/go-libp2p/p2p/net/swarm" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" + quic "github.com/libp2p/go-libp2p/p2p/transport/quic" + "github.com/libp2p/go-libp2p/p2p/transport/tcp" + webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" + "github.com/minio/sha256-simd" ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" "github.com/stretchr/testify/require" ) @@ -56,3 +68,37 @@ func TestAddressesWithoutListening(t *testing.T) { require.NoError(t, err) require.Empty(t, a1, "expected to be listening on no addresses") } + +func TestDialAddressSelection(t *testing.T) { + priv, _, err := test.RandTestKeyPair(ic.Ed25519, 256) + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(priv) + require.NoError(t, err) + s, err := swarm.NewSwarm("local", nil) + require.NoError(t, err) + + tcpTr, err := tcp.NewTCPTransport(nil, nil) + require.NoError(t, err) + require.NoError(t, s.AddTransport(tcpTr)) + quicTr, err := quic.NewTransport(priv, nil, nil, nil) + require.NoError(t, err) + require.NoError(t, s.AddTransport(quicTr)) + webtransportTr, err := webtransport.New(priv, nil, nil) + require.NoError(t, err) + require.NoError(t, s.AddTransport(webtransportTr)) + h := sha256.Sum256([]byte("foo")) + hash, err := multihash.Encode(h[:], multihash.SHA2_256) + require.NoError(t, err) + certHash, err := multibase.Encode(multibase.Base58BTC, hash) + require.NoError(t, err) + circuitTr, err := circuitv2.New(nil, nil) + require.NoError(t, err) + require.NoError(t, s.AddTransport(circuitTr)) + + require.Equal(t, tcpTr, s.TransportForDialing(ma.StringCast("/ip4/127.0.0.1/tcp/1234"))) + require.Equal(t, quicTr, s.TransportForDialing(ma.StringCast("/ip4/127.0.0.1/udp/1234/quic"))) + require.Equal(t, circuitTr, s.TransportForDialing(ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/p2p-circuit/p2p/%s", id)))) + require.Equal(t, webtransportTr, s.TransportForDialing(ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s", certHash)))) + require.Nil(t, s.TransportForDialing(ma.StringCast("/ip4/1.2.3.4"))) + require.Nil(t, s.TransportForDialing(ma.StringCast("/ip4/1.2.3.4/tcp/443/ws"))) +} diff --git a/p2p/net/swarm/swarm_transport.go b/p2p/net/swarm/swarm_transport.go index 96a3a8e2cd..924f0384aa 100644 --- a/p2p/net/swarm/swarm_transport.go +++ b/p2p/net/swarm/swarm_transport.go @@ -19,6 +19,7 @@ func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport { s.transports.RLock() defer s.transports.RUnlock() + if len(s.transports.m) == 0 { // make sure we're not just shutting down. if s.transports.m != nil { @@ -26,18 +27,15 @@ func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport { } return nil } - - for _, p := range protocols { - transport, ok := s.transports.m[p.Code] - if !ok { - continue - } - if transport.Proxy() { - return transport + if isRelayAddr(a) { + return s.transports.m[ma.P_CIRCUIT] + } + for _, t := range s.transports.m { + if t.CanDial(a) { + return t } } - - return s.transports.m[protocols[len(protocols)-1].Code] + return nil } // TransportForListening retrieves the appropriate transport for listening on