From ec8ebaabf12c7fb8a6107556fc5bb7ed6ac6600e Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 17 Oct 2024 04:41:15 +0530 Subject: [PATCH] holepuncher: pass address function in constructor (#2979) * holepunch: pass address function in constructor * nit * Remove getPublicAddrs --------- Co-authored-by: Marco Munizaga --- p2p/host/basic/basic_host.go | 11 +++- p2p/protocol/holepunch/holepunch_test.go | 15 +++-- p2p/protocol/holepunch/holepuncher.go | 23 +++---- p2p/protocol/holepunch/svc.go | 80 +++++++----------------- p2p/protocol/holepunch/util.go | 9 +-- 5 files changed, 56 insertions(+), 82 deletions(-) diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index f8ca3a179d..1ec51021f2 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -274,7 +274,16 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { opts.HolePunchingOptions = append(hpOpts, opts.HolePunchingOptions...) } - h.hps, err = holepunch.NewService(h, h.ids, opts.HolePunchingOptions...) + h.hps, err = holepunch.NewService(h, h.ids, func() []ma.Multiaddr { + addrs := h.AllAddrs() + if opts.AddrsFactory != nil { + addrs = opts.AddrsFactory(addrs) + } + // AllAddrs may ignore observed addresses in favour of NAT mappings. Use both for hole punching. + addrs = append(addrs, h.ids.OwnObservedAddrs()...) + addrs = ma.Unique(addrs) + return slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) }) + }, opts.HolePunchingOptions...) if err != nil { return nil, fmt.Errorf("failed to create hole punch service: %w", err) } diff --git a/p2p/protocol/holepunch/holepunch_test.go b/p2p/protocol/holepunch/holepunch_test.go index 23593c7970..00a76023ed 100644 --- a/p2p/protocol/holepunch/holepunch_test.go +++ b/p2p/protocol/holepunch/holepunch_test.go @@ -3,6 +3,7 @@ package holepunch_test import ( "context" "net" + "slices" "sync" "testing" "time" @@ -94,7 +95,6 @@ func TestNoHolePunchIfDirectConnExists(t *testing.T) { require.GreaterOrEqual(t, nc1, 1) nc2 := len(h2.Network().ConnsToPeer(h1.ID())) require.GreaterOrEqual(t, nc2, 1) - require.NoError(t, hps.DirectConnect(h2.ID())) require.Len(t, h1.Network().ConnsToPeer(h2.ID()), nc1) require.Len(t, h2.Network().ConnsToPeer(h1.ID()), nc2) @@ -473,8 +473,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc hps = addHolePunchService(t, h2, h2opt...) } - // h1 has a relay addr - // h2 should connect to the relay addr + // h2 has a relay addr var raddr ma.Multiaddr for _, a := range h2.Addrs() { if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil { @@ -483,6 +482,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc } } require.NotEmpty(t, raddr) + // h1 should connect to the relay addr require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{ ID: h2.ID(), Addrs: []ma.Multiaddr{raddr}, @@ -492,7 +492,11 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc func addHolePunchService(t *testing.T, h host.Host, opts ...holepunch.Option) *holepunch.Service { t.Helper() - hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...) + hps, err := holepunch.NewService(h, newMockIDService(t, h), func() []ma.Multiaddr { + addrs := h.Addrs() + addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) }) + return append(addrs, ma.StringCast("/ip4/1.2.3.4/tcp/1234")) + }, opts...) require.NoError(t, err) return hps } @@ -505,7 +509,6 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, libp2p.ResourceManager(&network.NullResourceManager{}), ) require.NoError(t, err) - hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...) - require.NoError(t, err) + hps := addHolePunchService(t, h, opts...) return h, hps } diff --git a/p2p/protocol/holepunch/holepuncher.go b/p2p/protocol/holepunch/holepuncher.go index a30e653761..20d0558fc5 100644 --- a/p2p/protocol/holepunch/holepuncher.go +++ b/p2p/protocol/holepunch/holepuncher.go @@ -37,7 +37,8 @@ type holePuncher struct { host host.Host refCount sync.WaitGroup - ids identify.IDService + ids identify.IDService + listenAddrs func() []ma.Multiaddr // active hole punches for deduplicating activeMx sync.Mutex @@ -50,13 +51,14 @@ type holePuncher struct { filter AddrFilter } -func newHolePuncher(h host.Host, ids identify.IDService, tracer *tracer, filter AddrFilter) *holePuncher { +func newHolePuncher(h host.Host, ids identify.IDService, listenAddrs func() []ma.Multiaddr, tracer *tracer, filter AddrFilter) *holePuncher { hp := &holePuncher{ - host: h, - ids: ids, - active: make(map[peer.ID]struct{}), - tracer: tracer, - filter: filter, + host: h, + ids: ids, + active: make(map[peer.ID]struct{}), + tracer: tracer, + filter: filter, + listenAddrs: listenAddrs, } hp.ctx, hp.ctxCancel = context.WithCancel(context.Background()) h.Network().Notify((*netNotifiee)(hp)) @@ -102,16 +104,15 @@ func (hp *holePuncher) directConnect(rp peer.ID) error { if getDirectConnection(hp.host, rp) != nil { return nil } - // short-circuit hole punching if a direct dial works. // attempt a direct connection ONLY if we have a public address for the remote peer for _, a := range hp.host.Peerstore().Addrs(rp) { - if manet.IsPublicAddr(a) && !isRelayAddress(a) { + if !isRelayAddress(a) && manet.IsPublicAddr(a) { forceDirectConnCtx := network.WithForceDirectDial(hp.ctx, "hole-punching") dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) tstart := time.Now() - // This dials *all* public addresses from the peerstore. + // This dials *all* addresses, public and private, from the peerstore. err := hp.host.Connect(dialCtx, peer.AddrInfo{ID: rp}) dt := time.Since(tstart) cancel() @@ -206,7 +207,7 @@ func (hp *holePuncher) initiateHolePunchImpl(str network.Stream) ([]ma.Multiaddr str.SetDeadline(time.Now().Add(StreamTimeout)) // send a CONNECT and start RTT measurement. - obsAddrs := removeRelayAddrs(hp.ids.OwnObservedAddrs()) + obsAddrs := removeRelayAddrs(hp.listenAddrs()) if hp.filter != nil { obsAddrs = hp.filter.FilterLocal(str.Conn().RemotePeer(), obsAddrs) } diff --git a/p2p/protocol/holepunch/svc.go b/p2p/protocol/holepunch/svc.go index eb8ad9fd38..2e6fdd1a6a 100644 --- a/p2p/protocol/holepunch/svc.go +++ b/p2p/protocol/holepunch/svc.go @@ -8,18 +8,15 @@ import ( "time" logging "github.com/ipfs/go-log/v2" - "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" - "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb" "github.com/libp2p/go-libp2p/p2p/protocol/identify" "github.com/libp2p/go-msgio/pbio" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" ) // Protocol is the libp2p protocol for Hole Punching. @@ -47,7 +44,13 @@ type Service struct { ctxCancel context.CancelFunc host host.Host - ids identify.IDService + // ids helps with connection reversal. We wait for identify to complete and attempt + // a direct connection to the peer if it's publicly reachable. + ids identify.IDService + // listenAddrs provides the addresses for the host to be used for hole punching. We use this + // and not host.Addrs because host.Addrs might remove public unreachable address and only advertise + // publicly reachable relay addresses. + listenAddrs func() []ma.Multiaddr holePuncherMx sync.Mutex holePuncher *holePuncher @@ -65,7 +68,9 @@ type Service struct { // no matter if they are behind a NAT / firewall or not. // The Service handles DCUtR streams (which are initiated from the node behind // a NAT / Firewall once we establish a connection to them through a relay. -func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, error) { +// +// listenAddrs MUST only return public addresses. +func NewService(h host.Host, ids identify.IDService, listenAddrs func() []ma.Multiaddr, opts ...Option) (*Service, error) { if ids == nil { return nil, errors.New("identify service can't be nil") } @@ -76,6 +81,7 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, ctxCancel: cancel, host: h, ids: ids, + listenAddrs: listenAddrs, hasPublicAddrsChan: make(chan struct{}), } @@ -88,18 +94,18 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, s.tracer.Start() s.refCount.Add(1) - go s.watchForPublicAddr() + go s.waitForPublicAddr() return s, nil } -func (s *Service) watchForPublicAddr() { +func (s *Service) waitForPublicAddr() { defer s.refCount.Done() log.Debug("waiting until we have at least one public address", "peer", s.host.ID()) // TODO: We should have an event here that fires when identify discovers a new - // address (and when autonat confirms that address). + // address. // As we currently don't have an event like this, just check our observed addresses // regularly (exponential backoff starting at 250 ms, capped at 5s). duration := 250 * time.Millisecond @@ -107,7 +113,7 @@ func (s *Service) watchForPublicAddr() { t := time.NewTimer(duration) defer t.Stop() for { - if len(s.getPublicAddrs()) > 0 { + if len(s.listenAddrs()) > 0 { log.Debug("Host now has a public address. Starting holepunch protocol.") s.host.SetStreamHandler(Protocol, s.handleNewStream) break @@ -125,36 +131,20 @@ func (s *Service) watchForPublicAddr() { } } - // Only start the holePuncher if we're behind a NAT / firewall. - sub, err := s.host.EventBus().Subscribe(&event.EvtLocalReachabilityChanged{}, eventbus.Name("holepunch")) - if err != nil { - log.Debugf("failed to subscripe to Reachability event: %s", err) + s.holePuncherMx.Lock() + if s.ctx.Err() != nil { + // service is closed return } - defer sub.Close() - for { - select { - case <-s.ctx.Done(): - return - case e, ok := <-sub.Out(): - if !ok { - return - } - if e.(event.EvtLocalReachabilityChanged).Reachability != network.ReachabilityPrivate { - continue - } - s.holePuncherMx.Lock() - s.holePuncher = newHolePuncher(s.host, s.ids, s.tracer, s.filter) - s.holePuncherMx.Unlock() - close(s.hasPublicAddrsChan) - return - } - } + s.holePuncher = newHolePuncher(s.host, s.ids, s.listenAddrs, s.tracer, s.filter) + s.holePuncherMx.Unlock() + close(s.hasPublicAddrsChan) } // Close closes the Hole Punch Service. func (s *Service) Close() error { var err error + s.ctxCancel() s.holePuncherMx.Lock() if s.holePuncher != nil { err = s.holePuncher.Close() @@ -162,7 +152,6 @@ func (s *Service) Close() error { s.holePuncherMx.Unlock() s.tracer.Close() s.host.RemoveStreamHandler(Protocol) - s.ctxCancel() s.refCount.Wait() return err } @@ -172,7 +161,7 @@ func (s *Service) incomingHolePunch(str network.Stream) (rtt time.Duration, remo if !isRelayAddress(str.Conn().RemoteMultiaddr()) { return 0, nil, nil, fmt.Errorf("received hole punch stream: %s", str.Conn().RemoteMultiaddr()) } - ownAddrs = s.getPublicAddrs() + ownAddrs = s.listenAddrs() if s.filter != nil { ownAddrs = s.filter.FilterLocal(str.Conn().RemotePeer(), ownAddrs) } @@ -275,29 +264,6 @@ func (s *Service) handleNewStream(str network.Stream) { s.tracer.HolePunchFinished("receiver", 1, addrs, ownAddrs, getDirectConnection(s.host, rp)) } -// getPublicAddrs returns public observed and interface addresses -func (s *Service) getPublicAddrs() []ma.Multiaddr { - addrs := removeRelayAddrs(s.ids.OwnObservedAddrs()) - - interfaceListenAddrs, err := s.host.Network().InterfaceListenAddresses() - if err != nil { - log.Debugf("failed to get to get InterfaceListenAddresses: %s", err) - } else { - addrs = append(addrs, interfaceListenAddrs...) - } - - addrs = ma.Unique(addrs) - - publicAddrs := make([]ma.Multiaddr, 0, len(addrs)) - - for _, addr := range addrs { - if manet.IsPublicAddr(addr) { - publicAddrs = append(publicAddrs, addr) - } - } - return publicAddrs -} - // DirectConnect is only exposed for testing purposes. // TODO: find a solution for this. func (s *Service) DirectConnect(p peer.ID) error { diff --git a/p2p/protocol/holepunch/util.go b/p2p/protocol/holepunch/util.go index 947b1ffd82..c0f34d0928 100644 --- a/p2p/protocol/holepunch/util.go +++ b/p2p/protocol/holepunch/util.go @@ -2,6 +2,7 @@ package holepunch import ( "context" + "slices" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -11,13 +12,7 @@ import ( ) func removeRelayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - result := make([]ma.Multiaddr, 0, len(addrs)) - for _, addr := range addrs { - if !isRelayAddress(addr) { - result = append(result, addr) - } - } - return result + return slices.DeleteFunc(addrs, isRelayAddress) } func isRelayAddress(a ma.Multiaddr) bool {