diff --git a/p2p/protocol/holepunch/holepunch_test.go b/p2p/protocol/holepunch/holepunch_test.go index 29d589cd7a..3275dc49cc 100644 --- a/p2p/protocol/holepunch/holepunch_test.go +++ b/p2p/protocol/holepunch/holepunch_test.go @@ -13,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" @@ -511,3 +512,57 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, require.NoError(t, err) return h, hps } + +func TestWebRTCDirectConnect(t *testing.T) { + relay1, err := libp2p.New() + require.NoError(t, err) + + _, err = relayv2.New(relay1) + require.NoError(t, err) + + relay1info := peer.AddrInfo{ + ID: relay1.ID(), + Addrs: relay1.Addrs(), + } + + h1, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), + libp2p.EnableHolePunching(), + ) + require.NoError(t, err) + + h2, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), + ) + require.NoError(t, err) + + err = h2.Connect(context.Background(), relay1info) + require.NoError(t, err) + + _, err = client.Reserve(context.Background(), h2, relay1info) + require.NoError(t, err) + + webrtcAddr := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/webrtc") + relayAddrs := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/") + h1.Peerstore().AddAddrs(h2.ID(), []ma.Multiaddr{webrtcAddr, relayAddrs}, peerstore.TempAddrTTL) + + err = h1.Connect(context.Background(), peer.AddrInfo{ID: h2.ID()}) + require.NoError(t, err) + require.Eventually( + t, + func() bool { + for _, c := range h1.Network().ConnsToPeer(h2.ID()) { + if !c.Stat().Transient { + return true + } + } + return false + }, + 5*time.Second, + 100*time.Millisecond, + ) +} diff --git a/p2p/protocol/holepunch/holepuncher.go b/p2p/protocol/holepunch/holepuncher.go index b651bd7822..2550cfbd6e 100644 --- a/p2p/protocol/holepunch/holepuncher.go +++ b/p2p/protocol/holepunch/holepuncher.go @@ -108,6 +108,9 @@ func (hp *holePuncher) directConnect(rp peer.ID) error { // short-circuit hole punching if a direct dial works. // attempt a direct connection ONLY if we have a public address for the remote peer for _, a := range hp.host.Peerstore().Addrs(rp) { + // Here we consider /webrtc addresses as relay addresses and skip them as they're + // also holepunched. We will dial the /webrtc addresses along with other addresses + // obtained in DCUtR if manet.IsPublicAddr(a) && !isRelayAddress(a) { forceDirectConnCtx := network.WithForceDirectDial(hp.ctx, "hole-punching") dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) diff --git a/p2p/protocol/holepunch/svc.go b/p2p/protocol/holepunch/svc.go index 47bf434fb1..4b19e5f811 100644 --- a/p2p/protocol/holepunch/svc.go +++ b/p2p/protocol/holepunch/svc.go @@ -84,6 +84,8 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, return nil, err } } + s.host.Network().Notify(s) + s.tracer.Start() s.refCount.Add(1) @@ -283,3 +285,39 @@ func (s *Service) DirectConnect(p peer.ID) error { s.holePuncherMx.Unlock() return holePuncher.DirectConnect(p) } + +var _ network.Notifiee = &Service{} + +func (s *Service) Connected(_ network.Network, conn network.Conn) { + // Dial /webrtc address if it's a relay connection to a browser node + if conn.Stat().Direction == network.DirOutbound && conn.Stat().Transient { + s.refCount.Add(1) + go func() { + defer s.refCount.Done() + select { + // waiting for Identify here will allow us to access the peer's public and observed addresses + // that we can dial to for a hole punch. + case <-s.ids.IdentifyWait(conn): + case <-s.ctx.Done(): + return + } + p := conn.RemotePeer() + // Peer supports DCUtR, let it trigger holepunch + if protos, err := s.host.Peerstore().SupportsProtocols(p, Protocol); err == nil && len(protos) > 0 { + return + } + // No DCUtR support, connect with peer over /webrtc + for _, addr := range s.host.Peerstore().Addrs(p) { + if _, err := addr.ValueForProtocol(ma.P_WEBRTC); err == nil { + ctx := network.WithForceDirectDial(s.ctx, "webrtc holepunch") + s.host.Connect(ctx, peer.AddrInfo{ID: p}) // address is already in peerstore + return + } + } + }() + } +} + +func (*Service) Disconnected(_ network.Network, v network.Conn) {} +func (*Service) Listen(n network.Network, a ma.Multiaddr) {} +func (*Service) ListenClose(n network.Network, a ma.Multiaddr) {}