Resolve data races when closing services (#10535)
* avoid logging if context is done

* add wait group to block Close until all goroutines return

* block Close() for recoverer and registry

* block Close() of states store

* states store

* rename threads -> threadsWG

* wg.Add before spawning a goroutine

* remove redundant assignment

the channel was created in the constructor

* added thread control helper

* refactor automation services to use common infra for service logic

* handle errors while spawning goroutines

* block subscriber

* remove thread control (was extracted to #10560)

* align thread control interface

* extract function that requires lock

* recoverer: spawn several goroutines for background work

* remove redundant log

* remove redundant check

* removed unused cancel

* reset flush ticker

* timer to ticker

* fix test
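
Note: the helper referenced by "added thread control helper" and "add wait group to block Close until all goroutines return" was extracted to #10560, so its implementation does not appear in this diff. The sketch below is a hypothetical, minimal version of the Go/Close contract the services rely on (a shared context cancelled on Close plus a sync.WaitGroup so Close blocks until every spawned goroutine returns); the threadControl type and its fields are illustrative, not the actual utils.ThreadControl code.

package threadctrl

import (
	"context"
	"sync"
)

// threadControl is a hypothetical sketch, not the actual utils.ThreadControl:
// Go ties each goroutine to a shared context, and Close cancels that context
// and waits for every goroutine started via Go to return.
type threadControl struct {
	ctx    context.Context
	cancel context.CancelFunc
	wg     sync.WaitGroup
}

func New() *threadControl {
	ctx, cancel := context.WithCancel(context.Background())
	return &threadControl{ctx: ctx, cancel: cancel}
}

// Go adds to the wait group before spawning, mirroring the
// "wg.Add before spawning a goroutine" item above.
func (tc *threadControl) Go(fn func(context.Context)) {
	tc.wg.Add(1)
	go func() {
		defer tc.wg.Done()
		fn(tc.ctx)
	}()
}

// Close cancels the shared context and blocks until all goroutines return,
// so a service's Close can safely tear down shared state afterwards.
func (tc *threadControl) Close() {
	tc.cancel()
	tc.wg.Wait()
}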
amirylm authored Sep 11, 2023
1 parent 36db6a4 commit 09d1534
Showing 8 changed files with 237 additions and 287 deletions.
109 changes: 54 additions & 55 deletions core/services/ocr2/plugins/ocr2keeper/evm21/block_subscriber.go
@@ -29,11 +29,15 @@ const (
 	blockHistorySize = int64(256)
 )
 
+var (
+	BlockSubscriberServiceName = "BlockSubscriber"
+)
+
 type BlockSubscriber struct {
-	sync       utils.StartStopOnce
+	utils.StartStopOnce
+	threadCtrl utils.ThreadControl
+
 	mu         sync.RWMutex
-	ctx        context.Context
-	cancel     context.CancelFunc
 	hb         httypes.HeadBroadcaster
 	lp         logpoller.LogPoller
 	headC      chan *evmtypes.Head
@@ -53,6 +57,7 @@ var _ ocr2keepers.BlockSubscriber = &BlockSubscriber{}
 
 func NewBlockSubscriber(hb httypes.HeadBroadcaster, lp logpoller.LogPoller, lggr logger.Logger) *BlockSubscriber {
 	return &BlockSubscriber{
+		threadCtrl: utils.NewThreadControl(),
 		hb:         hb,
 		lp:         lp,
 		headC:      make(chan *evmtypes.Head, channelSize),
@@ -81,8 +86,8 @@ func (bs *BlockSubscriber) getBlockRange(ctx context.Context) ([]uint64, error)
 	return blocks, nil
 }
 
-func (bs *BlockSubscriber) initializeBlocks(blocks []uint64) error {
-	logpollerBlocks, err := bs.lp.GetBlocksRange(bs.ctx, blocks, pg.WithParentCtx(bs.ctx))
+func (bs *BlockSubscriber) initializeBlocks(ctx context.Context, blocks []uint64) error {
+	logpollerBlocks, err := bs.lp.GetBlocksRange(ctx, blocks)
 	if err != nil {
 		return err
 	}
@@ -127,67 +132,61 @@ func (bs *BlockSubscriber) cleanup() {
 	bs.lggr.Infof("lastClearedBlock is set to %d", bs.lastClearedBlock)
 }
 
-func (bs *BlockSubscriber) Start(_ context.Context) error {
-	bs.lggr.Info("block subscriber started.")
-	return bs.sync.StartOnce("BlockSubscriber", func() error {
-		bs.mu.Lock()
-		defer bs.mu.Unlock()
-		bs.ctx, bs.cancel = context.WithCancel(context.Background())
-		// initialize the blocks map with the recent blockSize blocks
-		blocks, err := bs.getBlockRange(bs.ctx)
-		if err != nil {
-			bs.lggr.Errorf("failed to get block range", err)
-		}
-		err = bs.initializeBlocks(blocks)
-		if err != nil {
-			bs.lggr.Errorf("failed to get log poller blocks", err)
-		}
-
-		_, bs.unsubscribe = bs.hb.Subscribe(&headWrapper{headC: bs.headC, lggr: bs.lggr})
-
+func (bs *BlockSubscriber) initialize(ctx context.Context) {
+	bs.mu.Lock()
+	defer bs.mu.Unlock()
+	// initialize the blocks map with the recent blockSize blocks
+	blocks, err := bs.getBlockRange(ctx)
+	if err != nil {
+		bs.lggr.Errorf("failed to get block range", err)
+	}
+	err = bs.initializeBlocks(ctx, blocks)
+	if err != nil {
+		bs.lggr.Errorf("failed to get log poller blocks", err)
+	}
+	_, bs.unsubscribe = bs.hb.Subscribe(&headWrapper{headC: bs.headC, lggr: bs.lggr})
+}
+
+func (bs *BlockSubscriber) Start(ctx context.Context) error {
+	return bs.StartOnce(BlockSubscriberServiceName, func() error {
+		bs.lggr.Info("block subscriber started.")
+		bs.initialize(ctx)
 		// poll from head broadcaster channel and push to subscribers
-		{
-			go func(ctx context.Context) {
-				for {
-					select {
-					case h := <-bs.headC:
-						if h != nil {
-							bs.processHead(h)
-						}
-					case <-ctx.Done():
-						return
-					}
-				}
-			}(bs.ctx)
-		}
-
-		// clean up block maps
-		{
-			go func(ctx context.Context) {
-				ticker := time.NewTicker(cleanUpInterval)
-				for {
-					select {
-					case <-ticker.C:
-						bs.cleanup()
-					case <-ctx.Done():
-						ticker.Stop()
-						return
-					}
-				}
-			}(bs.ctx)
-		}
+		bs.threadCtrl.Go(func(ctx context.Context) {
+			for {
+				select {
+				case h := <-bs.headC:
+					if h != nil {
+						bs.processHead(h)
+					}
+				case <-ctx.Done():
+					return
+				}
+			}
+		})
+		// cleanup old blocks
+		bs.threadCtrl.Go(func(ctx context.Context) {
+			ticker := time.NewTicker(cleanUpInterval)
+			defer ticker.Stop()
+
+			for {
+				select {
+				case <-ticker.C:
+					bs.cleanup()
+				case <-ctx.Done():
+					return
+				}
+			}
+		})
 
 		return nil
 	})
 }
 
 func (bs *BlockSubscriber) Close() error {
-	bs.lggr.Info("stop block subscriber")
-	return bs.sync.StopOnce("BlockSubscriber", func() error {
-		bs.mu.Lock()
-		defer bs.mu.Unlock()
-
-		bs.cancel()
+	return bs.StopOnce(BlockSubscriberServiceName, func() error {
+		bs.lggr.Info("stop block subscriber")
+		bs.threadCtrl.Close()
 		bs.unsubscribe()
 		return nil
 	})
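
Taken together, the block_subscriber.go hunks above converge on a single lifecycle shape: StartOnce/StopOnce guard repeated Start/Close calls, and every background goroutine is spawned through the thread controller so that Close cancels their context and waits for them to exit. The sketch below distills that pattern into a hypothetical service type; it reuses utils.StartStopOnce and utils.ThreadControl exactly as they are used in the diff, but the service itself is illustrative and not code from this commit.

package example

import (
	"context"
	"time"

	"github.com/smartcontractkit/chainlink/v2/core/utils"
)

// service is a hypothetical example of the Start/Close pattern used above.
type service struct {
	utils.StartStopOnce
	threadCtrl utils.ThreadControl
}

func newService() *service {
	return &service{threadCtrl: utils.NewThreadControl()}
}

func (s *service) Start(ctx context.Context) error {
	return s.StartOnce("ExampleService", func() error {
		// One-time initialization can use the caller's ctx directly.
		s.threadCtrl.Go(func(ctx context.Context) {
			ticker := time.NewTicker(time.Second)
			defer ticker.Stop()
			for {
				select {
				case <-ticker.C:
					// periodic background work goes here
				case <-ctx.Done():
					return // ctx is cancelled by threadCtrl.Close()
				}
			}
		})
		return nil
	})
}

func (s *service) Close() error {
	return s.StopOnce("ExampleService", func() error {
		// Cancels the workers' context and blocks until they return,
		// which is what resolves the data races on shared state.
		s.threadCtrl.Close()
		return nil
	})
}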
@@ -158,7 +158,7 @@ func TestBlockSubscriber_InitializeBlocks(t *testing.T) {
 			bs := NewBlockSubscriber(hb, lp, lggr)
 			bs.blockHistorySize = historySize
 			bs.blockSize = blockSize
-			err := bs.initializeBlocks(tc.Blocks)
+			err := bs.initializeBlocks(testutils.Context(t), tc.Blocks)
 
 			if tc.Error != nil {
 				assert.Equal(t, tc.Error.Error(), err.Error())
80 changes: 30 additions & 50 deletions core/services/ocr2/plugins/ocr2keeper/evm21/logprovider/provider.go
@@ -22,9 +22,12 @@ import (
 	"github.com/smartcontractkit/chainlink/v2/core/logger"
 	"github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ocr2keeper/evm21/core"
 	"github.com/smartcontractkit/chainlink/v2/core/services/pg"
+	"github.com/smartcontractkit/chainlink/v2/core/utils"
 )
 
 var (
+	LogProviderServiceName = "LogEventProvider"
+
 	ErrHeadNotAvailable   = fmt.Errorf("head not available")
 	ErrBlockLimitExceeded = fmt.Errorf("block limit exceeded")
 
@@ -78,9 +81,10 @@ var _ LogEventProviderTest = &logEventProvider{}
 
 // logEventProvider manages log filters for upkeeps and enables to read the log events.
 type logEventProvider struct {
-	lggr logger.Logger
+	utils.StartStopOnce
+	threadCtrl utils.ThreadControl
 
-	cancel context.CancelFunc
+	lggr logger.Logger
 
 	poller logpoller.LogPoller
 
@@ -99,8 +103,9 @@ type logEventProvider struct {
 
 func NewLogProvider(lggr logger.Logger, poller logpoller.LogPoller, packer LogDataPacker, filterStore UpkeepFilterStore, opts LogTriggersOptions) *logEventProvider {
 	return &logEventProvider{
-		packer:      packer,
+		threadCtrl:  utils.NewThreadControl(),
 		lggr:        lggr.Named("KeepersRegistry.LogEventProvider"),
+		packer:      packer,
 		buffer:      newLogEventBuffer(lggr, int(opts.LookbackBlocks), maxLogsPerBlock, maxLogsPerUpkeepInBlock),
 		poller:      poller,
 		opts:        opts,
@@ -109,69 +114,44 @@ }
 }
 
 func (p *logEventProvider) Start(context.Context) error {
-	ctx, cancel := context.WithCancel(context.Background())
-
-	p.lock.Lock()
-	if p.cancel != nil {
-		p.lock.Unlock()
-		cancel() // Cancel the created context
-		return errors.New("already started")
-	}
-	p.cancel = cancel
-	p.lock.Unlock()
-
-	readQ := make(chan []*big.Int, readJobQueueSize)
-
-	p.lggr.Infow("starting log event provider", "readInterval", p.opts.ReadInterval, "readMaxBatchSize", readMaxBatchSize, "readers", readerThreads)
-
-	{ // start readers
-		go func(ctx context.Context) {
-			for i := 0; i < readerThreads; i++ {
-				go p.startReader(ctx, readQ)
-			}
-		}(ctx)
-	}
-
-	{ // start scheduler
-		lggr := p.lggr.With("where", "scheduler")
-		go func(ctx context.Context) {
-			err := p.scheduleReadJobs(ctx, func(ids []*big.Int) {
+	return p.StartOnce(LogProviderServiceName, func() error {
+
+		readQ := make(chan []*big.Int, readJobQueueSize)
+
+		p.lggr.Infow("starting log event provider", "readInterval", p.opts.ReadInterval, "readMaxBatchSize", readMaxBatchSize, "readers", readerThreads)
+
+		for i := 0; i < readerThreads; i++ {
+			p.threadCtrl.Go(func(ctx context.Context) {
+				p.startReader(ctx, readQ)
+			})
+		}
+
+		p.threadCtrl.Go(func(ctx context.Context) {
+			lggr := p.lggr.With("where", "scheduler")
+
+			p.scheduleReadJobs(ctx, func(ids []*big.Int) {
 				select {
 				case readQ <- ids:
 				case <-ctx.Done():
 				default:
 					lggr.Warnw("readQ is full, dropping ids", "ids", ids)
 				}
 			})
-			// if the context was canceled, we don't need to log the error
-			if ctx.Err() != nil {
-				return
-			}
-			if err != nil {
-				lggr.Warnw("stopped scheduling read jobs with error", "err", err)
-			}
-			lggr.Debug("stopped scheduling read jobs")
-		}(ctx)
-	}
+		})
 
-	return nil
+		return nil
+	})
 }
 
 func (p *logEventProvider) Close() error {
-	p.lock.Lock()
-	defer p.lock.Unlock()
-
-	if cancel := p.cancel; cancel != nil {
-		p.cancel = nil
-		cancel()
-	} else {
-		return errors.New("already stopped")
-	}
-	return nil
+	return p.StopOnce(LogProviderServiceName, func() error {
+		p.threadCtrl.Close()
+		return nil
+	})
 }
 
-func (p *logEventProvider) Name() string {
-	return p.lggr.Name()
+func (p *logEventProvider) HealthReport() map[string]error {
+	return map[string]error{LogProviderServiceName: p.Healthy()}
 }
 
 func (p *logEventProvider) GetLatestPayloads(ctx context.Context) ([]ocr2keepers.UpkeepPayload, error) {
@@ -237,7 +217,7 @@ func (p *logEventProvider) CurrentPartitionIdx() uint64 {
 }
 
 // scheduleReadJobs starts a scheduler that pushed ids to readQ for reading logs in the background.
-func (p *logEventProvider) scheduleReadJobs(pctx context.Context, execute func([]*big.Int)) error {
+func (p *logEventProvider) scheduleReadJobs(pctx context.Context, execute func([]*big.Int)) {
 	ctx, cancel := context.WithCancel(pctx)
 	defer cancel()
 
Expand Down Expand Up @@ -265,7 +245,7 @@ func (p *logEventProvider) scheduleReadJobs(pctx context.Context, execute func([
partitionIdx++
atomic.StoreUint64(&p.currentPartitionIdx, partitionIdx)
case <-ctx.Done():
return ctx.Err()
return
}
}
}
@@ -193,7 +193,7 @@ func TestLogEventProvider_ScheduleReadJobs(t *testing.T) {
 	reads := make(chan []*big.Int, 100)
 
 	go func(ctx context.Context) {
-		_ = p.scheduleReadJobs(ctx, func(ids []*big.Int) {
+		p.scheduleReadJobs(ctx, func(ids []*big.Int) {
 			select {
 			case reads <- ids:
 			default:
