From 7bbd5073ff79404c0b503918199d46d19c582434 Mon Sep 17 00:00:00 2001 From: acud <12988138+acud@users.noreply.github.com> Date: Fri, 9 Aug 2024 15:00:38 -0600 Subject: [PATCH] feat: add write-coalesing to atx handling --- activation/e2e/builds_atx_v2_test.go | 3 + activation/handler.go | 5 + activation/handler_test.go | 3 + activation/handler_v1.go | 163 ++++++++++++++++++++++++--- activation/handler_v1_test.go | 36 +++--- activation/metrics/metrics.go | 18 +++ checkpoint/recovery_test.go | 8 +- node/node.go | 4 + 8 files changed, 208 insertions(+), 32 deletions(-) diff --git a/activation/e2e/builds_atx_v2_test.go b/activation/e2e/builds_atx_v2_test.go index 2baf326060..1b6287d610 100644 --- a/activation/e2e/builds_atx_v2_test.go +++ b/activation/e2e/builds_atx_v2_test.go @@ -134,6 +134,9 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) { logger, activation.WithAtxVersions(atxVersions), ) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + go atxHdlr.Start(ctx) var previous *types.ActivationTx var publishedAtxs atomic.Uint32 diff --git a/activation/handler.go b/activation/handler.go index da99dd999d..fcfd6181eb 100644 --- a/activation/handler.go +++ b/activation/handler.go @@ -131,6 +131,7 @@ func NewHandler( beacon: beacon, tortoise: tortoise, signers: make(map[types.NodeID]*signing.EdSigner), + atxBatchResult: nil, }, v2: &HandlerV2{ @@ -169,6 +170,10 @@ func (h *Handler) Register(sig *signing.EdSigner) { h.v1.Register(sig) } +func (h *Handler) Start(ctx context.Context) { + h.v1.flushAtxLoop(ctx) +} + // HandleSyncedAtx handles atxs received by sync. func (h *Handler) HandleSyncedAtx(ctx context.Context, expHash types.Hash32, peer p2p.Peer, data []byte) error { _, err := h.handleAtx(ctx, expHash, peer, data) diff --git a/activation/handler_test.go b/activation/handler_test.go index fd7a668962..3ae942acb6 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -214,6 +214,9 @@ func newTestHandler(tb testing.TB, goldenATXID types.ATXID, opts ...HandlerOptio lg, opts..., ) + ctx, cancel := context.WithCancel(context.Background()) + go atxHdlr.Start(ctx) + tb.Cleanup(func() { cancel() }) return &testHandler{ Handler: atxHdlr, cdb: cdb, diff --git a/activation/handler_v1.go b/activation/handler_v1.go index 481c90df24..01cf0addda 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -13,6 +13,7 @@ import ( "go.uber.org/zap" "golang.org/x/exp/maps" + "github.com/spacemeshos/go-spacemesh/activation/metrics" "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/codec" @@ -30,6 +31,8 @@ import ( "github.com/spacemeshos/go-spacemesh/system" ) +var sqlWriterSleep = 100 * time.Millisecond + type nipostValidatorV1 interface { InitialNIPostChallengeV1(challenge *wire.NIPostChallengeV1, atxs atxProvider, goldenATXID types.ATXID) error NIPostChallengeV1(challenge *wire.NIPostChallengeV1, previous *types.ActivationTx, nodeID types.NodeID) error @@ -83,6 +86,20 @@ type HandlerV1 struct { signerMtx sync.Mutex signers map[types.NodeID]*signing.EdSigner + + atxMu sync.Mutex + atxBatch []atxBatchItem + atxBatchResult *batchResult +} + +type batchResult struct { + doneC chan struct{} + err error +} + +type atxBatchItem struct { + atx *types.ActivationTx + watx *wire.ActivationTxV1 } func (h *HandlerV1) Register(sig *signing.EdSigner) { @@ -97,6 +114,75 @@ func (h *HandlerV1) Register(sig *signing.EdSigner) { h.signers[sig.NodeID()] = sig } +const poolItemMinSize = 1000 // minimum size of atx batch (to save on allocation) +var pool = &sync.Pool{ + New: func() any { + s := make([]atxBatchItem, 0, poolItemMinSize) + return &s + }, +} + +func getBatch() []atxBatchItem { + v := pool.Get().(*[]atxBatchItem) + return *v +} + +func putBatch(v []atxBatchItem) { + v = v[:0] + pool.Put(&v) +} + +func (h *HandlerV1) flushAtxLoop(ctx context.Context) { + t := time.NewTicker(sqlWriterSleep) + // initialize the first batch + h.atxMu.Lock() + h.atxBatchResult = &batchResult{doneC: make(chan struct{})} + h.atxMu.Unlock() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + // copy-on-write + h.atxMu.Lock() + if len(h.atxBatch) == 0 { + h.atxMu.Unlock() + continue + } + batch := h.atxBatch // copy the existing slice + h.atxBatch = getBatch() // make a new one + res := h.atxBatchResult // copy the result type + h.atxBatchResult = &batchResult{doneC: make(chan struct{})} // make a new one + h.atxMu.Unlock() + metrics.FlushBatchSize.Add(float64(len(batch))) + + if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { + var err error + for _, item := range batch { + err = atxs.Add(tx, item.atx, item.watx.Blob()) + if err != nil && !errors.Is(err, sql.ErrObjectExists) { + metrics.WriteBatchErrorsCount.Inc() + return fmt.Errorf("add atx to db: %w", err) + } + err = atxs.SetPost(tx, item.atx.ID(), item.watx.PrevATXID, 0, + item.atx.SmesherID, item.watx.NumUnits) + if err != nil && !errors.Is(err, sql.ErrObjectExists) { + metrics.WriteBatchErrorsCount.Inc() + return fmt.Errorf("set atx units: %w", err) + } + } + return nil + }); err != nil { + res.err = err + metrics.ErroredBatchCount.Inc() + h.logger.Error("flush atxs to db", zap.Error(err)) + } + putBatch(batch) + close(res.doneC) + } + } +} + func (h *HandlerV1) syntacticallyValidate(ctx context.Context, atx *wire.ActivationTxV1) error { if atx.NIPost == nil { return fmt.Errorf("nil nipost for atx %s", atx.ID()) @@ -489,37 +575,86 @@ func (h *HandlerV1) checkWrongPrevAtx( func (h *HandlerV1) checkMalicious( ctx context.Context, - tx *sql.Tx, + exec sql.Executor, watx *wire.ActivationTxV1, ) (*mwire.MalfeasanceProof, error) { - malicious, err := identities.IsMalicious(tx, watx.SmesherID) + malicious, err := identities.IsMalicious(exec, watx.SmesherID) if err != nil { return nil, fmt.Errorf("checking if node is malicious: %w", err) } if malicious { return nil, nil } - proof, err := h.checkDoublePublish(ctx, tx, watx) + proof, err := h.checkDoublePublish(ctx, exec, watx) if proof != nil || err != nil { return proof, err } - return h.checkWrongPrevAtx(ctx, tx, watx) + return h.checkWrongPrevAtx(ctx, exec, watx) } -// storeAtx stores an ATX and notifies subscribers of the ATXID. func (h *HandlerV1) storeAtx( ctx context.Context, atx *types.ActivationTx, watx *wire.ActivationTxV1, + deps bool, ) (*mwire.MalfeasanceProof, error) { - var proof *mwire.MalfeasanceProof + var ( + c chan struct{} + proof *mwire.MalfeasanceProof + br *batchResult + err error + ) + proof, err = h.checkMalicious(ctx, h.cdb, watx) + if err != nil { + return proof, fmt.Errorf("check malicious: %w", err) + } + if !deps { + h.atxMu.Lock() + h.atxBatch = append(h.atxBatch, atxBatchItem{atx: atx, watx: watx}) + br = h.atxBatchResult + c = br.doneC + h.atxMu.Unlock() + } else { + // we have deps, persist with sync flow + return proof, h.storeAtxSync(ctx, atx, watx, proof) + } + + select { + case <-c: + // wait for the batch the corresponds to the atx to be written + err = br.err + case <-ctx.Done(): + err = ctx.Err() + } + + atxs.AtxAdded(h.cdb, atx) + if proof != nil { + h.cdb.CacheMalfeasanceProof(atx.SmesherID, proof) + h.tortoise.OnMalfeasance(atx.SmesherID) + } + + added := h.cacheAtx(ctx, atx) + h.beacon.OnAtx(atx) + if added != nil { + h.tortoise.OnAtx(atx.TargetEpoch(), atx.ID(), added) + } + + h.logger.Debug("finished storing atx in epoch", + zap.Stringer("atx_id", atx.ID()), + zap.Uint32("epoch_id", atx.PublishEpoch.Uint32()), + ) + return proof, err +} + +// storeAtx stores an ATX and notifies subscribers of the ATXID. +func (h *HandlerV1) storeAtxSync( + ctx context.Context, + atx *types.ActivationTx, + watx *wire.ActivationTxV1, + proof *mwire.MalfeasanceProof, +) error { if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { var err error - proof, err = h.checkMalicious(ctx, tx, watx) - if err != nil { - return fmt.Errorf("check malicious: %w", err) - } - err = atxs.Add(tx, atx, watx.Blob()) if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("add atx to db: %w", err) @@ -531,7 +666,7 @@ func (h *HandlerV1) storeAtx( return nil }); err != nil { - return nil, fmt.Errorf("store atx: %w", err) + return fmt.Errorf("store atx: %w", err) } atxs.AtxAdded(h.cdb, atx) @@ -550,7 +685,7 @@ func (h *HandlerV1) storeAtx( zap.Stringer("atx_id", atx.ID()), zap.Uint32("epoch_id", atx.PublishEpoch.Uint32()), ) - return proof, nil + return nil } func (h *HandlerV1) processATX( @@ -627,7 +762,7 @@ func (h *HandlerV1) processATX( } atx.Weight = weight - proof, err = h.storeAtx(ctx, atx, watx) + proof, err = h.storeAtx(ctx, atx, watx, len(atxIDs) > 0) if err != nil { return nil, fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err) } diff --git a/activation/handler_v1_test.go b/activation/handler_v1_test.go index 5db33b1098..13e73b3695 100644 --- a/activation/handler_v1_test.go +++ b/activation/handler_v1_test.go @@ -35,7 +35,7 @@ func newV1TestHandler(tb testing.TB, goldenATXID types.ATXID) *v1TestHandler { lg := zaptest.NewLogger(tb) cdb := datastore.NewCachedDB(sql.InMemory(), lg) mocks := newTestHandlerMocks(tb, goldenATXID) - return &v1TestHandler{ + v1 := &v1TestHandler{ HandlerV1: &HandlerV1{ local: "localID", cdb: cdb, @@ -53,6 +53,10 @@ func newV1TestHandler(tb testing.TB, goldenATXID types.ATXID) *v1TestHandler { }, handlerMocks: mocks, } + ctx, cancel := context.WithCancel(context.Background()) + tb.Cleanup(cancel) + go v1.HandlerV1.flushAtxLoop(ctx) + return v1 } func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { @@ -549,7 +553,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == watx.ID() })) atxHdlr.mtortoise.EXPECT().OnAtx(watx.PublishEpoch+1, watx.ID(), gomock.Any()) - proof, err := atxHdlr.storeAtx(context.Background(), atx, watx) + proof, err := atxHdlr.storeAtx(context.Background(), atx, watx, false) require.NoError(t, err) require.Nil(t, proof) @@ -570,7 +574,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == watx.ID() })) atxHdlr.mtortoise.EXPECT().OnAtx(watx.PublishEpoch+1, watx.ID(), gomock.Any()) - proof, err := atxHdlr.storeAtx(context.Background(), atx, watx) + proof, err := atxHdlr.storeAtx(context.Background(), atx, watx, false) require.NoError(t, err) require.Nil(t, proof) @@ -578,7 +582,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == watx.ID() })) // Note: tortoise is not informed about the same ATX again - proof, err = atxHdlr.storeAtx(context.Background(), atx, watx) + proof, err = atxHdlr.storeAtx(context.Background(), atx, watx, false) require.NoError(t, err) require.Nil(t, proof) }) @@ -598,7 +602,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == watx.ID() })) atxHdlr.mtortoise.EXPECT().OnAtx(watx.PublishEpoch+1, watx.ID(), gomock.Any()) - proof, err := atxHdlr.storeAtx(context.Background(), atx, watx) + proof, err := atxHdlr.storeAtx(context.Background(), atx, watx, false) require.NoError(t, err) require.Nil(t, proof) @@ -619,7 +623,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == watx0.ID() })) atxHdlr.mtortoise.EXPECT().OnAtx(watx0.PublishEpoch+1, watx0.ID(), gomock.Any()) - proof, err := atxHdlr.storeAtx(context.Background(), atx0, watx0) + proof, err := atxHdlr.storeAtx(context.Background(), atx0, watx0, false) require.NoError(t, err) require.Nil(t, proof) @@ -633,7 +637,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { })) atxHdlr.mtortoise.EXPECT().OnAtx(watx1.PublishEpoch+1, watx1.ID(), gomock.Any()) atxHdlr.mtortoise.EXPECT().OnMalfeasance(sig.NodeID()) - proof, err = atxHdlr.storeAtx(context.Background(), atx1, watx1) + proof, err = atxHdlr.storeAtx(context.Background(), atx1, watx1, false) require.NoError(t, err) require.NotNil(t, proof) require.Equal(t, mwire.MultipleATXs, proof.Proof.Type) @@ -660,7 +664,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == watx0.ID() })) atxHdlr.mtortoise.EXPECT().OnAtx(watx0.PublishEpoch+1, watx0.ID(), gomock.Any()) - proof, err := atxHdlr.storeAtx(context.Background(), atx0, watx0) + proof, err := atxHdlr.storeAtx(context.Background(), atx0, watx0, false) require.NoError(t, err) require.Nil(t, proof) @@ -669,7 +673,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { watx1.Sign(sig) atx1 := toAtx(t, watx1) - proof, err = atxHdlr.storeAtx(context.Background(), atx1, watx1) + proof, err = atxHdlr.storeAtx(context.Background(), atx1, watx1, false) require.ErrorContains(t, err, fmt.Sprintf("%s already published an ATX", sig.NodeID().ShortString()), @@ -692,7 +696,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == initialATX.ID() })) atxHdlr.mtortoise.EXPECT().OnAtx(initialATX.PublishEpoch+1, initialATX.ID(), gomock.Any()) - proof, err := atxHdlr.storeAtx(context.Background(), wInitialATX, initialATX) + proof, err := atxHdlr.storeAtx(context.Background(), wInitialATX, initialATX, false) require.NoError(t, err) require.Nil(t, proof) @@ -705,7 +709,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == watx1.ID() })) atxHdlr.mtortoise.EXPECT().OnAtx(watx1.PublishEpoch+1, watx1.ID(), gomock.Any()) - proof, err = atxHdlr.storeAtx(context.Background(), atx1, watx1) + proof, err = atxHdlr.storeAtx(context.Background(), atx1, watx1, false) require.NoError(t, err) require.Nil(t, proof) @@ -717,7 +721,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == watx2.ID() })) atxHdlr.mtortoise.EXPECT().OnAtx(watx2.PublishEpoch+1, watx2.ID(), gomock.Any()) - proof, err = atxHdlr.storeAtx(context.Background(), atx2, watx2) + proof, err = atxHdlr.storeAtx(context.Background(), atx2, watx2, false) require.NoError(t, err) require.Nil(t, proof) @@ -732,7 +736,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { })) atxHdlr.mtortoise.EXPECT().OnAtx(watx3.PublishEpoch+1, watx3.ID(), gomock.Any()) atxHdlr.mtortoise.EXPECT().OnMalfeasance(sig.NodeID()) - proof, err = atxHdlr.storeAtx(context.Background(), atx3, watx3) + proof, err = atxHdlr.storeAtx(context.Background(), atx3, watx3, false) require.NoError(t, err) require.NotNil(t, proof) require.Equal(t, mwire.InvalidPrevATX, proof.Proof.Type) @@ -756,7 +760,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == wInitialATX.ID() })) atxHdlr.mtortoise.EXPECT().OnAtx(wInitialATX.PublishEpoch+1, wInitialATX.ID(), gomock.Any()) - proof, err := atxHdlr.storeAtx(context.Background(), initialAtx, wInitialATX) + proof, err := atxHdlr.storeAtx(context.Background(), initialAtx, wInitialATX, false) require.NoError(t, err) require.Nil(t, proof) @@ -769,7 +773,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { return atx.(*types.ActivationTx).ID() == watx1.ID() })) atxHdlr.mtortoise.EXPECT().OnAtx(watx1.PublishEpoch+1, watx1.ID(), gomock.Any()) - proof, err = atxHdlr.storeAtx(context.Background(), atx1, watx1) + proof, err = atxHdlr.storeAtx(context.Background(), atx1, watx1, false) require.NoError(t, err) require.Nil(t, proof) @@ -779,7 +783,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { watx2.Sign(sig) atx2 := toAtx(t, watx2) - proof, err = atxHdlr.storeAtx(context.Background(), atx2, watx2) + proof, err = atxHdlr.storeAtx(context.Background(), atx2, watx2, false) require.ErrorContains(t, err, fmt.Sprintf("%s referenced incorrect previous ATX", sig.NodeID().ShortString()), diff --git a/activation/metrics/metrics.go b/activation/metrics/metrics.go index b1d68d0283..90bb5b4105 100644 --- a/activation/metrics/metrics.go +++ b/activation/metrics/metrics.go @@ -50,3 +50,21 @@ var PostVerificationLatency = metrics.NewHistogramWithBuckets( []string{}, prometheus.ExponentialBuckets(1, 2, 20), ).WithLabelValues() + +var WriteBatchErrorsCount = prometheus.NewCounter(metrics.NewCounterOpts( + namespace, + "write_batch_errors", + "number of errors when writing a batch", +)) + +var ErroredBatchCount = prometheus.NewCounter(metrics.NewCounterOpts( + namespace, + "errored_batch", + "number of batches that errored", +)) + +var FlushBatchSize = prometheus.NewCounter(metrics.NewCounterOpts( + namespace, + "flush_batch_size", + "size of flushed batch", +)) diff --git a/checkpoint/recovery_test.go b/checkpoint/recovery_test.go index d67a259225..9beaf38dfa 100644 --- a/checkpoint/recovery_test.go +++ b/checkpoint/recovery_test.go @@ -168,7 +168,7 @@ func TestRecover(t *testing.T) { } bsdir := filepath.Join(cfg.DataDir, bootstrap.DirName) require.NoError(t, fs.MkdirAll(bsdir, 0o700)) - db := sql.InMemory() + db := sql.InMemoryTest(t) localDB := localsql.InMemory() data, err := checkpoint.RecoverWithDb(context.Background(), zaptest.NewLogger(t), db, localDB, fs, cfg) if tc.expErr != nil { @@ -209,7 +209,7 @@ func TestRecover_SameRecoveryInfo(t *testing.T) { } bsdir := filepath.Join(cfg.DataDir, bootstrap.DirName) require.NoError(t, fs.MkdirAll(bsdir, 0o700)) - db := sql.InMemory() + db := sql.InMemoryTest(t) localDB := localsql.InMemory() types.SetEffectiveGenesis(0) require.NoError(t, recovery.SetCheckpoint(db, types.LayerID(recoverLayer))) @@ -267,6 +267,10 @@ func validateAndPreserveData( mtrtl, lg, ) + ctx, cancel := context.WithCancel(context.Background()) + tb.Cleanup(cancel) + go atxHandler.Start(ctx) + mfetch.EXPECT().GetAtxs(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() for _, dep := range deps { var atx wire.ActivationTxV1 diff --git a/node/node.go b/node/node.go index cb00f9998b..27fdbd428c 100644 --- a/node/node.go +++ b/node/node.go @@ -746,6 +746,10 @@ func (app *App) initServices(ctx context.Context) error { for _, sig := range app.signers { atxHandler.Register(sig) } + app.eg.Go(func() error { + atxHandler.Start(ctx) + return nil + }) // we can't have an epoch offset which is greater/equal than the number of layers in an epoch