diff --git a/core/services/registrysyncer/syncer.go b/core/services/registrysyncer/syncer.go index 811841fe390..7886e715f01 100644 --- a/core/services/registrysyncer/syncer.go +++ b/core/services/registrysyncer/syncer.go @@ -56,10 +56,9 @@ type registrySyncer struct { updateChan chan *LocalRegistry - wg sync.WaitGroup - lggr logger.Logger - mu sync.RWMutex - readerMu sync.RWMutex + wg sync.WaitGroup + lggr logger.Logger + mu sync.RWMutex } var _ services.Service = ®istrySyncer{} @@ -198,9 +197,6 @@ func (s *registrySyncer) updateStateLoop() { } func (s *registrySyncer) localRegistry(ctx context.Context) (*LocalRegistry, error) { - s.readerMu.RLock() - defer s.readerMu.RUnlock() - var caps []kcr.CapabilitiesRegistryCapabilityInfo err := s.reader.GetLatestValue(ctx, "CapabilitiesRegistry", "getCapabilities", primitives.Unconfirmed, nil, &caps) if err != nil { @@ -274,7 +270,6 @@ func (s *registrySyncer) Sync(ctx context.Context, isInitialSync bool) error { return nil } - s.readerMu.Lock() if s.reader == nil { reader, err := s.initReader(ctx, s.lggr, s.relayer, s.registryAddress) if err != nil { @@ -283,7 +278,6 @@ func (s *registrySyncer) Sync(ctx context.Context, isInitialSync bool) error { s.reader = reader } - s.readerMu.Unlock() var lr *LocalRegistry var err error diff --git a/core/services/registrysyncer/syncer_test.go b/core/services/registrysyncer/syncer_test.go index 821fae7d423..2c08a1cdde6 100644 --- a/core/services/registrysyncer/syncer_test.go +++ b/core/services/registrysyncer/syncer_test.go @@ -142,6 +142,38 @@ func (l *launcher) Launch(ctx context.Context, localRegistry *registrysyncer.Loc return nil } +type orm struct { + ormMock *syncerMocks.ORM + latestLocalRegistryCh chan struct{} + addLocalRegistryCh chan struct{} +} + +func newORM(t *testing.T) *orm { + t.Helper() + + return &orm{ + ormMock: syncerMocks.NewORM(t), + latestLocalRegistryCh: make(chan struct{}, 1), + addLocalRegistryCh: make(chan struct{}, 1), + } +} + +func (o *orm) Cleanup() { + close(o.latestLocalRegistryCh) + close(o.addLocalRegistryCh) +} + +func (o *orm) AddLocalRegistry(ctx context.Context, localRegistry registrysyncer.LocalRegistry) error { + o.addLocalRegistryCh <- struct{}{} + err := o.ormMock.AddLocalRegistry(ctx, localRegistry) + return err +} + +func (o *orm) LatestLocalRegistry(ctx context.Context) (*registrysyncer.LocalRegistry, error) { + o.latestLocalRegistryCh <- struct{}{} + return o.ormMock.LatestLocalRegistry(ctx) +} + func toPeerIDs(ids [][32]byte) []p2ptypes.PeerID { var pids []p2ptypes.PeerID for _, id := range ids { @@ -408,22 +440,35 @@ func TestSyncer_DBIntegration(t *testing.T) { require.NoError(t, err) factory := newContractReaderFactory(t, sim) - syncerORM := syncerMocks.NewORM(t) + syncerORM := newORM(t) + syncerORM.ormMock.On("LatestLocalRegistry", mock.Anything).Return(nil, fmt.Errorf("no state found")) + syncerORM.ormMock.On("AddLocalRegistry", mock.Anything, mock.Anything).Return(nil) syncer, err := newTestSyncer(logger.TestLogger(t), func() (p2ptypes.PeerID, error) { return p2ptypes.PeerID{}, nil }, factory, regAddress.Hex(), syncerORM) require.NoError(t, err) require.NoError(t, syncer.Start(ctx)) t.Cleanup(func() { + syncerORM.Cleanup() require.NoError(t, syncer.Close()) }) l := &launcher{} syncer.AddLauncher(l) - syncerORM.On("LatestLocalRegistry", mock.Anything).Return(nil, fmt.Errorf("no state found")) - syncerORM.On("AddLocalRegistry", mock.Anything, mock.Anything).Return(nil) - - err = syncer.Sync(ctx, false) // should store the data into the DB - require.NoError(t, err) + var latestLocalRegistryCalled, addLocalRegistryCalled bool + timeout := time.After(500 * time.Millisecond) + + for !latestLocalRegistryCalled || !addLocalRegistryCalled { + select { + case val := <-syncerORM.latestLocalRegistryCh: + assert.Equal(t, struct{}{}, val) + latestLocalRegistryCalled = true + case val := <-syncerORM.addLocalRegistryCh: + assert.Equal(t, struct{}{}, val) + addLocalRegistryCalled = true + case <-timeout: + t.Fatal("test timed out; channels did not received data") + } + } } func TestSyncer_LocalNode(t *testing.T) { @@ -509,7 +554,7 @@ func newTestSyncer( getPeerID func() (p2ptypes.PeerID, error), relayer registrysyncer.ContractReaderFactory, registryAddress string, - orm *syncerMocks.ORM, + orm *orm, ) (registrysyncer.RegistrySyncer, error) { rs, err := registrysyncer.New(lggr, getPeerID, relayer, registryAddress, orm) if err != nil {