From e45fe9e746767a5e4cb990efb161cd3567e711c8 Mon Sep 17 00:00:00 2001 From: sukun Date: Tue, 3 Dec 2024 20:21:18 +0530 Subject: [PATCH] add tests for present behavior --- p2p/host/basic/address_service.go | 16 +-- p2p/host/basic/address_service_test.go | 174 +++++++++++++++++++++++++ p2p/host/basic/basic_host.go | 99 +++++++------- 3 files changed, 229 insertions(+), 60 deletions(-) diff --git a/p2p/host/basic/address_service.go b/p2p/host/basic/address_service.go index 73d359c857..bf102cd886 100644 --- a/p2p/host/basic/address_service.go +++ b/p2p/host/basic/address_service.go @@ -9,7 +9,6 @@ import ( "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/record" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" @@ -19,8 +18,6 @@ import ( manet "github.com/multiformats/go-multiaddr/net" ) -type peerRecordFunc func([]ma.Multiaddr) (*record.Envelope, error) - type observedAddrsService interface { OwnObservedAddrs() []ma.Multiaddr ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr @@ -34,12 +31,13 @@ type addressService struct { addrsChangeChan chan struct{} addrsUpdated chan struct{} autoRelayAddrsSub event.Subscription - autoRelayAddrs func() []ma.Multiaddr - reachability func() network.Reachability - ifaceAddrs *interfaceAddrsCache - wg sync.WaitGroup - ctx context.Context - ctxCancel context.CancelFunc + // There are wrapped in to functions for mocking + autoRelayAddrs func() []ma.Multiaddr + reachability func() network.Reachability + ifaceAddrs *interfaceAddrsCache + wg sync.WaitGroup + ctx context.Context + ctxCancel context.CancelFunc } func NewAddressService(h *BasicHost, natmgr func(network.Network) NATManager, diff --git a/p2p/host/basic/address_service_test.go b/p2p/host/basic/address_service_test.go index 6ee611c8a2..d86cb162e7 100644 --- a/p2p/host/basic/address_service_test.go +++ b/p2p/host/basic/address_service_test.go @@ -2,7 +2,10 @@ package basichost import ( "testing" + "time" + "github.com/libp2p/go-libp2p/core/network" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" @@ -96,3 +99,174 @@ func TestAppendNATAddrs(t *testing.T) { }) } } + +type mockNatManager struct { + GetMappingFunc func(addr ma.Multiaddr) ma.Multiaddr + HasDiscoveredNATFunc func() bool +} + +func (m *mockNatManager) Close() error { + return nil +} + +func (m *mockNatManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr { + return m.GetMappingFunc(addr) +} + +func (m *mockNatManager) HasDiscoveredNAT() bool { + return m.HasDiscoveredNATFunc() +} + +var _ NATManager = &mockNatManager{} + +type mockObservedAddrs struct { + OwnObservedAddrsFunc func() []ma.Multiaddr + ObservedAddrsForFunc func(ma.Multiaddr) []ma.Multiaddr +} + +func (m *mockObservedAddrs) OwnObservedAddrs() []ma.Multiaddr { + return m.OwnObservedAddrsFunc() +} + +func (m *mockObservedAddrs) ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr { + return m.ObservedAddrsForFunc(local) +} + +func TestAddressService(t *testing.T) { + getAddrService := func() *addressService { + h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{DisableIdentifyAddressDiscovery: true}) + require.NoError(t, err) + t.Cleanup(func() { h.Close() }) + + as := h.addressService + return as + } + + t.Run("NAT Address", func(t *testing.T) { + as := getAddrService() + as.natmgr = &mockNatManager{ + HasDiscoveredNATFunc: func() bool { return true }, + GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr { + if _, err := addr.ValueForProtocol(ma.P_UDP); err == nil { + return ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + } + return nil + }, + } + require.Contains(t, as.Addrs(), ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")) + }) + + t.Run("NAT And Observed Address", func(t *testing.T) { + as := getAddrService() + as.natmgr = &mockNatManager{ + HasDiscoveredNATFunc: func() bool { return true }, + GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr { + if _, err := addr.ValueForProtocol(ma.P_UDP); err == nil { + return ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + } + return nil + }, + } + as.observedAddrsService = &mockObservedAddrs{ + ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr { + if _, err := addr.ValueForProtocol(ma.P_TCP); err == nil { + return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/1")} + } + return nil + }, + } + require.Contains(t, as.Addrs(), ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")) + require.Contains(t, as.Addrs(), ma.StringCast("/ip4/2.2.2.2/tcp/1")) + }) + t.Run("Only Observed Address", func(t *testing.T) { + as := getAddrService() + as.natmgr = nil + as.observedAddrsService = &mockObservedAddrs{ + ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr { + if _, err := addr.ValueForProtocol(ma.P_TCP); err == nil { + return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/1")} + } + return nil + }, + OwnObservedAddrsFunc: func() []ma.Multiaddr { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")} + }, + } + require.NotContains(t, as.Addrs(), ma.StringCast("/ip4/2.2.2.2/tcp/1")) + require.Contains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")) + }) + t.Run("Public Addrs Removed When Private", func(t *testing.T) { + as := getAddrService() + as.natmgr = nil + as.observedAddrsService = &mockObservedAddrs{ + OwnObservedAddrsFunc: func() []ma.Multiaddr { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")} + }, + } + as.reachability = func() network.Reachability { + return network.ReachabilityPrivate + } + relayAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/p2p/QmdXGaeGiVA745XorV1jr11RHxB9z4fqykm6xCUPX1aTJo/p2p-circuit") + as.autoRelayAddrs = func() []ma.Multiaddr { + return []ma.Multiaddr{relayAddr} + } + require.NotContains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")) + require.Contains(t, as.Addrs(), relayAddr) + require.Contains(t, as.AllAddrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")) + }) + + t.Run("AddressFactory gets relay addresses", func(t *testing.T) { + as := getAddrService() + as.natmgr = nil + as.observedAddrsService = &mockObservedAddrs{ + OwnObservedAddrsFunc: func() []ma.Multiaddr { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")} + }, + } + as.reachability = func() network.Reachability { + return network.ReachabilityPrivate + } + relayAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/p2p/QmdXGaeGiVA745XorV1jr11RHxB9z4fqykm6xCUPX1aTJo/p2p-circuit") + as.autoRelayAddrs = func() []ma.Multiaddr { + return []ma.Multiaddr{relayAddr} + } + as.addrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { + for _, a := range addrs { + if a.Equal(relayAddr) { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")} + } + } + return nil + } + require.Contains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")) + require.NotContains(t, as.Addrs(), relayAddr) + }) + + t.Run("updates addresses on signaling", func(t *testing.T) { + as := getAddrService() + as.natmgr = nil + updateChan := make(chan struct{}) + a1 := ma.StringCast("/ip4/1.1.1.1/udp/1/quic-v1") + a2 := ma.StringCast("/ip4/1.1.1.1/tcp/1") + as.addrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { + select { + case <-updateChan: + return []ma.Multiaddr{a2} + default: + return []ma.Multiaddr{a1} + } + } + as.Start() + require.Contains(t, as.Addrs(), a1) + require.NotContains(t, as.Addrs(), a2) + close(updateChan) + as.SignalAddressChange() + select { + case <-as.AddrsUpdated(): + require.Contains(t, as.Addrs(), a2) + require.NotContains(t, as.Addrs(), a1) + case <-time.After(2 * time.Second): + t.Fatal("expected addrs to be updated") + } + }) +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index ac7cb9bd0f..de5b5aae18 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -90,6 +90,7 @@ type BasicHost struct { disableSignedPeerRecord bool signKey crypto.PrivKey + caBook peerstore.CertifiedAddrBook autoNATMx sync.RWMutex autoNat autonat.AutoNAT @@ -309,11 +310,12 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { if !ok { return nil, errors.New("peerstore should also be a certified address book") } + h.caBook = cab rec, err := h.makeSignedPeerRecord(h.addressService.Addrs()) if err != nil { return nil, fmt.Errorf("failed to create signed record for self: %w", err) } - if _, err := cab.ConsumePeerRecord(rec, peerstore.PermanentAddrTTL); err != nil { + if _, err := h.caBook.ConsumePeerRecord(rec, peerstore.PermanentAddrTTL); err != nil { return nil, fmt.Errorf("failed to persist signed record to peerstore: %w", err) } } @@ -398,56 +400,6 @@ func (h *BasicHost) newStreamHandler(s network.Stream) { handle(protoID, s) } -func (h *BasicHost) background() { - defer h.refCount.Done() - var lastAddrs []ma.Multiaddr - - // TODO: Deprecate this event and logic - emitAddrChange := func(currentAddrs []ma.Multiaddr, lastAddrs []ma.Multiaddr) { - changeEvt := h.makeUpdatedAddrEvent(lastAddrs, currentAddrs) - if changeEvt == nil { - return - } - // Our addresses have changed. - // store the signed peer record in the peer store. - if !h.disableSignedPeerRecord { - cabook, ok := peerstore.GetCertifiedAddrBook(h.Peerstore()) - if !ok { - log.Errorf("peerstore doesn't implement certified address book") - return - } - if _, err := cabook.ConsumePeerRecord(changeEvt.SignedPeerRecord, peerstore.PermanentAddrTTL); err != nil { - log.Errorf("failed to persist signed peer record in peer store, err=%s", err) - return - } - } - // update host addresses in the peer store - removedAddrs := make([]ma.Multiaddr, 0, len(changeEvt.Removed)) - for _, ua := range changeEvt.Removed { - removedAddrs = append(removedAddrs, ua.Address) - } - h.Peerstore().SetAddrs(h.ID(), currentAddrs, peerstore.PermanentAddrTTL) - h.Peerstore().SetAddrs(h.ID(), removedAddrs, 0) - - // emit addr change event - if err := h.emitters.evtLocalAddrsUpdated.Emit(*changeEvt); err != nil { - log.Warnf("error emitting event for updated addrs: %s", err) - } - } - - for { - curr := h.Addrs() - emitAddrChange(curr, lastAddrs) - lastAddrs = curr - - select { - case <-h.addressService.AddrsUpdated(): - case <-h.ctx.Done(): - return - } - } -} - func (h *BasicHost) makeUpdatedAddrEvent(prev, current []ma.Multiaddr) *event.EvtLocalAddressesUpdated { if prev == nil && current == nil { return nil @@ -515,6 +467,51 @@ func (h *BasicHost) makeSignedPeerRecord(addrs []ma.Multiaddr) (*record.Envelope return record.Seal(rec, h.signKey) } +func (h *BasicHost) background() { + defer h.refCount.Done() + var lastAddrs []ma.Multiaddr + + // TODO: Deprecate this event and logic + emitAddrChange := func(currentAddrs []ma.Multiaddr, lastAddrs []ma.Multiaddr) { + changeEvt := h.makeUpdatedAddrEvent(lastAddrs, currentAddrs) + if changeEvt == nil { + return + } + // Our addresses have changed. + // store the signed peer record in the peer store. + if !h.disableSignedPeerRecord { + if _, err := h.caBook.ConsumePeerRecord(changeEvt.SignedPeerRecord, peerstore.PermanentAddrTTL); err != nil { + log.Errorf("failed to persist signed peer record in peer store, err=%s", err) + return + } + } + // update host addresses in the peer store + removedAddrs := make([]ma.Multiaddr, 0, len(changeEvt.Removed)) + for _, ua := range changeEvt.Removed { + removedAddrs = append(removedAddrs, ua.Address) + } + h.Peerstore().SetAddrs(h.ID(), currentAddrs, peerstore.PermanentAddrTTL) + h.Peerstore().SetAddrs(h.ID(), removedAddrs, 0) + + // emit addr change event + if err := h.emitters.evtLocalAddrsUpdated.Emit(*changeEvt); err != nil { + log.Warnf("error emitting event for updated addrs: %s", err) + } + } + + for { + curr := h.Addrs() + emitAddrChange(curr, lastAddrs) + lastAddrs = curr + + select { + case <-h.addressService.AddrsUpdated(): + case <-h.ctx.Done(): + return + } + } +} + // ID returns the (local) peer.ID associated with this Host func (h *BasicHost) ID() peer.ID { return h.Network().LocalPeer()