From dc5fcaf4733695c36b312fe597e7acfca68cf4e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Tue, 6 Aug 2024 08:10:45 +0000 Subject: [PATCH] detect double atx-merging malfeasance (#6135) ## Motivation Publishing two merged ATXs in the same epoch is forbidden, even if the sets of IDs participating in both are disjoint. For example, given a married set of IDs (A, B, C, D), it's not allowed to publish two merged ATXs, one with IDs (A, B) and the second with (C, D). The C and D can publish **separately**. --- activation/e2e/checkpoint_merged_test.go | 6 + activation/handler_v2.go | 37 ++++++ activation/handler_v2_test.go | 140 ++++++++++++++++++----- activation/wire/malfeasance.go | 1 + checkpoint/recovery.go | 4 + checkpoint/recovery_test.go | 3 +- checkpoint/runner.go | 5 + checkpoint/runner_test.go | 37 ++++++ common/types/activation.go | 7 +- common/types/checkpoint.go | 1 + sql/atxs/atxs.go | 64 ++++++++--- sql/atxs/atxs_test.go | 52 +++++++++ sql/migrations/state/0020_atx_merge.sql | 1 + 13 files changed, 315 insertions(+), 43 deletions(-) diff --git a/activation/e2e/checkpoint_merged_test.go b/activation/e2e/checkpoint_merged_test.go index cc5cfb00bd..3984d926f8 100644 --- a/activation/e2e/checkpoint_merged_test.go +++ b/activation/e2e/checkpoint_merged_test.go @@ -273,6 +273,12 @@ func Test_CheckpointAfterMerge(t *testing.T) { require.Equal(t, i, marriage.Index) } + checkpointedMerged, err := atxs.Get(newDB, mergedATX.ID()) + require.NoError(t, err) + require.True(t, checkpointedMerged.Golden()) + require.NotNil(t, checkpointedMerged.MarriageATX) + require.Equal(t, marriageATX.ID(), *checkpointedMerged.MarriageATX) + // 4. Spawn new ATX handler and builder using the new DB poetDb = activation.NewPoetDb(newDB, logger.Named("poetDb")) cdb = datastore.NewCachedDB(newDB, logger) diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 176b25fd0d..155b5db6d1 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -117,6 +117,7 @@ func (h *HandlerV2) processATX( atx := &types.ActivationTx{ PublishEpoch: watx.PublishEpoch, + MarriageATX: watx.MarriageATX, Coinbase: watx.Coinbase, BaseTickHeight: baseTickHeight, NumUnits: parts.effectiveUnits, @@ -684,6 +685,14 @@ func (h *HandlerV2) checkMalicious( return nil } + malicious, err = h.checkDoubleMerge(ctx, tx, watx) + if err != nil { + return fmt.Errorf("checking double merge: %w", err) + } + if malicious { + return nil + } + // TODO(mafa): contextual validation: // 1. check double-publish = ID contributed post to two ATXs in the same epoch // 2. check previous ATX @@ -746,6 +755,34 @@ func (h *HandlerV2) checkDoublePost( return false, nil } +func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2) (bool, error) { + if watx.MarriageATX == nil { + return false, nil + } + ids, err := atxs.MergeConflict(tx, *watx.MarriageATX, watx.PublishEpoch) + switch { + case errors.Is(err, sql.ErrNotFound): + return false, nil + case err != nil: + return false, fmt.Errorf("searching for ATXs with the same marriage ATX: %w", err) + } + otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != watx.ID() }) + other := ids[otherIndex] + + h.logger.Debug("second merged ATX for single marriage - creating malfeasance proof", + zap.Stringer("marriage_atx", *watx.MarriageATX), + zap.Stringer("atx", watx.ID()), + zap.Stringer("other_atx", other), + zap.Stringer("smesher_id", watx.SmesherID), + ) + + // TODO(mafa): finish proof + proof := &wire.ATXProof{ + ProofType: wire.DoubleMerge, + } + return true, h.malPublisher.Publish(ctx, watx.SmesherID, proof) +} + // Store an ATX in the DB. func (h *HandlerV2) storeAtx( ctx context.Context, diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index db7967817a..06f52f9b69 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -614,18 +614,16 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { func marryIDs( t testing.TB, atxHandler *v2TestHandler, - sig *signing.EdSigner, + signers []*signing.EdSigner, golden types.ATXID, - num int, ) (marriage *wire.ActivationTxV2, other []*wire.ActivationTxV2) { + sig := signers[0] mATX := newInitialATXv2(t, golden) mATX.Marriages = []wire.MarriageCertificate{{ Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), }} - for range num { - signer, err := signing.NewEdSigner() - require.NoError(t, err) + for _, signer := range signers[1:] { atx := atxHandler.createAndProcessInitial(t, signer) other = append(other, atx) mATX.Marriages = append(mATX.Marriages, wire.MarriageCertificate{ @@ -644,20 +642,27 @@ func marryIDs( func TestHandlerV2_ProcessMergedATX(t *testing.T) { t.Parallel() - golden := types.RandomATXID() - sig, err := signing.NewEdSigner() - require.NoError(t, err) + var ( + golden = types.RandomATXID() + signers []*signing.EdSigner + equivocationSet []types.NodeID + ) + for range 4 { + sig, err := signing.NewEdSigner() + require.NoError(t, err) + signers = append(signers, sig) + equivocationSet = append(equivocationSet, sig.NodeID()) + } + sig := signers[0] t.Run("happy case", func(t *testing.T) { atxHandler := newV2TestHandler(t, golden) // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 2) + mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) previousATXs := []types.ATXID{mATX.ID()} - equivocationSet := []types.NodeID{sig.NodeID()} for _, atx := range otherATXs { previousATXs = append(previousATXs, atx.ID()) - equivocationSet = append(equivocationSet, atx.SmesherID) } // Process a merged ATX @@ -694,12 +699,10 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler.tickSize = tickSize // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 4) + mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) previousATXs := []types.ATXID{mATX.ID()} - equivocationSet := []types.NodeID{sig.NodeID()} for _, atx := range otherATXs { previousATXs = append(previousATXs, atx.ID()) - equivocationSet = append(equivocationSet, atx.SmesherID) } // Process a merged ATX @@ -765,12 +768,10 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler := newV2TestHandler(t, golden) // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 2) + mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) previousATXs := []types.ATXID{} - equivocationSet := []types.NodeID{sig.NodeID()} for _, atx := range otherATXs { previousATXs = append(previousATXs, atx.ID()) - equivocationSet = append(equivocationSet, atx.SmesherID) } // Process a merged ATX @@ -802,12 +803,10 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler := newV2TestHandler(t, golden) // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 1) + mATX, otherATXs := marryIDs(t, atxHandler, signers[:2], golden) previousATXs := []types.ATXID{mATX.ID()} - equivocationSet := []types.NodeID{sig.NodeID()} for _, atx := range otherATXs { previousATXs = append(previousATXs, atx.ID()) - equivocationSet = append(equivocationSet, atx.SmesherID) } // Process a merged ATX @@ -836,12 +835,10 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler := newV2TestHandler(t, golden) // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 1) + mATX, otherATXs := marryIDs(t, atxHandler, signers[:2], golden) previousATXs := []types.ATXID{mATX.ID()} - equivocationSet := []types.NodeID{sig.NodeID()} for _, atx := range otherATXs { previousATXs = append(previousATXs, atx.ID()) - equivocationSet = append(equivocationSet, atx.SmesherID) } // Process a merged ATX @@ -868,11 +865,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler := newV2TestHandler(t, golden) // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 1) - equivocationSet := []types.NodeID{sig.NodeID()} - for _, atx := range otherATXs { - equivocationSet = append(equivocationSet, atx.SmesherID) - } + mATX, _ := marryIDs(t, atxHandler, signers, golden) prev := atxs.CheckpointAtx{ Epoch: mATX.PublishEpoch + 1, @@ -932,6 +925,97 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { err = atxHandler.processATX(context.Background(), "", merged, time.Now()) require.ErrorIs(t, err, pubsub.ErrValidationReject) }) + t.Run("publishing two merged ATXs from one marriage set is malfeasance", func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + // Marry 4 IDs + mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) + previousATXs := []types.ATXID{mATX.ID()} + for _, atx := range otherATXs { + previousATXs = append(previousATXs, atx.ID()) + } + + // Process a merged ATX for 2 IDs + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + merged.NiPosts[0].Posts = []wire.SubPostV2{} + for i := range equivocationSet[:2] { + post := wire.SubPostV2{ + MarriageIndex: uint32(i), + PrevATXIndex: uint32(i), + NumUnits: 4, + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + } + + mATXID := mATX.ID() + + merged.MarriageATX = &mATXID + merged.PreviousATXs = []types.ATXID{mATX.ID(), otherATXs[0].ID()} + merged.Sign(sig) + + atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) + err := atxHandler.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + + // Process a second merged ATX for the same equivocation set, but different IDs + merged = newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + merged.NiPosts[0].Posts = []wire.SubPostV2{} + for i := range equivocationSet[:2] { + post := wire.SubPostV2{ + MarriageIndex: uint32(i + 2), + PrevATXIndex: uint32(i), + NumUnits: 4, + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + } + + mATXID = mATX.ID() + merged.MarriageATX = &mATXID + merged.PreviousATXs = []types.ATXID{otherATXs[1].ID(), otherATXs[2].ID()} + merged.Sign(signers[2]) + + atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) + atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), merged.SmesherID, gomock.Any()) + err = atxHandler.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + }) + t.Run("publishing two merged ATXs (one checkpointed)", func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) + mATXID := mATX.ID() + + // Insert checkpointed merged ATX + checkpointedATX := &atxs.CheckpointAtx{ + Epoch: mATX.PublishEpoch + 2, + ID: types.RandomATXID(), + SmesherID: signers[0].NodeID(), + MarriageATX: &mATXID, + } + require.NoError(t, atxs.AddCheckpointed(atxHandler.cdb, checkpointedATX)) + + // create and process another merged ATX + merged := newSoloATXv2(t, checkpointedATX.Epoch, mATX.ID(), golden) + merged.NiPosts[0].Posts = []wire.SubPostV2{} + for i := range equivocationSet[2:] { + post := wire.SubPostV2{ + MarriageIndex: uint32(i + 2), + PrevATXIndex: uint32(i), + NumUnits: 4, + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + } + + merged.MarriageATX = &mATXID + merged.PreviousATXs = []types.ATXID{otherATXs[1].ID(), otherATXs[2].ID()} + merged.Sign(signers[2]) + atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) + // TODO: this could be syntactically validated as all nodes in the network + // should already have the checkpointed merged ATX. + atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), merged.SmesherID, gomock.Any()) + err := atxHandler.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + }) } func TestCollectDeps_AtxV2(t *testing.T) { diff --git a/activation/wire/malfeasance.go b/activation/wire/malfeasance.go index c00bbcd984..d8e60a4127 100644 --- a/activation/wire/malfeasance.go +++ b/activation/wire/malfeasance.go @@ -12,6 +12,7 @@ type ProofType byte const ( DoublePublish ProofType = iota + 1 DoubleMarry + DoubleMerge InvalidPost ) diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 97a77d5468..31df1cdff1 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -356,6 +356,10 @@ func checkpointData(fs afero.Fs, file string, newGenesis types.LayerID) (*recove cAtx.ID = types.ATXID(types.BytesToHash(atx.ID)) cAtx.Epoch = types.EpochID(atx.Epoch) cAtx.CommitmentATX = types.ATXID(types.BytesToHash(atx.CommitmentAtx)) + if len(atx.MarriageAtx) == 32 { + marriageATXID := types.ATXID(atx.MarriageAtx) + cAtx.MarriageATX = &marriageATXID + } cAtx.SmesherID = types.BytesToNodeID(atx.PublicKey) cAtx.NumUnits = atx.NumUnits cAtx.VRFNonce = types.VRFPostIndex(atx.VrfNonce) diff --git a/checkpoint/recovery_test.go b/checkpoint/recovery_test.go index b21f9e9d06..d67a259225 100644 --- a/checkpoint/recovery_test.go +++ b/checkpoint/recovery_test.go @@ -908,7 +908,8 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { require.NoError(t, err) atxid, err := hex.DecodeString("98e47278c1f58acfd2b670a730f28898f74eb140482a07b91ff81f9ff0b7d9f4") require.NoError(t, err) - atx := newAtx(types.ATXID(atxid), types.EmptyATXID, nil, 3, 1, 0, nid) + atx := &types.ActivationTx{SmesherID: types.NodeID(nid)} + atx.SetID(types.ATXID(atxid)) cfg := &checkpoint.RecoverConfig{ GoldenAtx: goldenAtx, diff --git a/checkpoint/runner.go b/checkpoint/runner.go index 2b72e236b0..7039aa54c0 100644 --- a/checkpoint/runner.go +++ b/checkpoint/runner.go @@ -78,10 +78,15 @@ func checkpointDB( if mal, ok := malicious[catx.SmesherID]; ok && mal { continue } + var marriageAtx []byte + if catx.MarriageATX != nil { + marriageAtx = catx.MarriageATX.Bytes() + } checkpoint.Data.Atxs = append(checkpoint.Data.Atxs, types.AtxSnapshot{ ID: catx.ID.Bytes(), Epoch: catx.Epoch.Uint32(), CommitmentAtx: catx.CommitmentATX.Bytes(), + MarriageAtx: marriageAtx, VrfNonce: uint64(catx.VRFNonce), NumUnits: catx.NumUnits, BaseTickHeight: catx.BaseTickHeight, diff --git a/checkpoint/runner_test.go b/checkpoint/runner_test.go index f7009c24ec..472f62a8be 100644 --- a/checkpoint/runner_test.go +++ b/checkpoint/runner_test.go @@ -238,10 +238,15 @@ func newAtx( } func asAtxSnapshot(v *types.ActivationTx, cmt *types.ATXID) types.AtxSnapshot { + var marriageATX []byte + if v.MarriageATX != nil { + marriageATX = v.MarriageATX.Bytes() + } return types.AtxSnapshot{ ID: v.ID().Bytes(), Epoch: v.PublishEpoch.Uint32(), CommitmentAtx: cmt.Bytes(), + MarriageAtx: marriageATX, VrfNonce: uint64(v.VRFNonce), NumUnits: v.NumUnits, BaseTickHeight: v.BaseTickHeight, @@ -375,3 +380,35 @@ func TestRunner_Generate_Error(t *testing.T) { require.Error(t, err) }) } + +func TestRunner_Generate_PreservesMarriageATX(t *testing.T) { + t.Parallel() + db := sql.InMemory() + + require.NoError(t, accounts.Update(db, &types.Account{Address: types.Address{1, 1}})) + + atx := &types.ActivationTx{ + CommitmentATX: &types.ATXID{1, 2, 3, 4, 5}, + MarriageATX: &types.ATXID{6, 7, 8, 9}, + SmesherID: types.RandomNodeID(), + NumUnits: 4, + } + atx.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) + require.NoError(t, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits)) + + fs := afero.NewMemMapFs() + dir, err := afero.TempDir(fs, "", "Generate") + require.NoError(t, err) + + err = checkpoint.Generate(context.Background(), fs, db, dir, 5, 2) + require.NoError(t, err) + + file, err := fs.Open(checkpoint.SelfCheckpointFilename(dir, 5)) + require.NoError(t, err) + defer file.Close() + + var checkpoint types.Checkpoint + require.NoError(t, json.NewDecoder(file).Decode(&checkpoint)) + require.Equal(t, atx.MarriageATX.Bytes(), checkpoint.Data.Atxs[0].MarriageAtx) +} diff --git a/common/types/activation.go b/common/types/activation.go index 41112efc4c..4a1da34c15 100644 --- a/common/types/activation.go +++ b/common/types/activation.go @@ -178,7 +178,9 @@ type ActivationTx struct { PrevATXID ATXID // CommitmentATX is the ATX used in the commitment for initializing the PoST of the node. - CommitmentATX *ATXID + CommitmentATX *ATXID + // The marriage ATX, used in merged ATXs only. + MarriageATX *ATXID Coinbase Address NumUnits uint32 // the minimum number of space units in this and the previous ATX BaseTickHeight uint64 @@ -231,6 +233,9 @@ func (atx *ActivationTx) MarshalLogObject(encoder log.ObjectEncoder) error { if atx.CommitmentATX != nil { encoder.AddString("commitment_atx_id", atx.CommitmentATX.String()) } + if atx.MarriageATX != nil { + encoder.AddString("marriage_atx_id", atx.MarriageATX.String()) + } encoder.AddUint64("vrf_nonce", uint64(atx.VRFNonce)) encoder.AddString("coinbase", atx.Coinbase.String()) encoder.AddUint32("epoch", atx.PublishEpoch.Uint32()) diff --git a/common/types/checkpoint.go b/common/types/checkpoint.go index 7f04b35a87..81184e6b30 100644 --- a/common/types/checkpoint.go +++ b/common/types/checkpoint.go @@ -17,6 +17,7 @@ type AtxSnapshot struct { ID []byte `json:"id"` Epoch uint32 `json:"epoch"` CommitmentAtx []byte `json:"commitmentAtx"` + MarriageAtx []byte `json:"marriageAtx"` VrfNonce uint64 `json:"vrfNonce"` BaseTickHeight uint64 `json:"baseTickHeight"` TickCount uint64 `json:"tickCount"` diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index d7a41d7167..5e14cddde1 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -22,7 +22,8 @@ const ( // filters that refer to the id column. const fieldsQuery = `select atxs.id, atxs.nonce, atxs.base_tick_height, atxs.tick_count, atxs.pubkey, atxs.effective_num_units, -atxs.received, atxs.epoch, atxs.sequence, atxs.coinbase, atxs.validity, atxs.prev_id, atxs.commitment_atx, atxs.weight` +atxs.received, atxs.epoch, atxs.sequence, atxs.coinbase, atxs.validity, atxs.prev_id, atxs.commitment_atx, atxs.weight, +atxs.marriage_atx` const fullQuery = fieldsQuery + ` from atxs` @@ -62,6 +63,10 @@ func decoder(fn decoderCallback) sql.Decoder { stmt.ColumnBytes(12, a.CommitmentATX[:]) } a.Weight = uint64(stmt.ColumnInt64(13)) + if stmt.ColumnType(14) != sqlite.SQLITE_NULL { + a.MarriageATX = new(types.ATXID) + stmt.ColumnBytes(14, a.MarriageATX[:]) + } return fn(&a) } @@ -425,8 +430,6 @@ func Add(db sql.Executor, atx *types.ActivationTx, blob types.AtxBlob) error { stmt.BindInt64(3, int64(atx.NumUnits)) if atx.CommitmentATX != nil { stmt.BindBytes(4, atx.CommitmentATX.Bytes()) - } else { - stmt.BindNull(4) } stmt.BindInt64(5, int64(atx.VRFNonce)) stmt.BindBytes(6, atx.SmesherID.Bytes()) @@ -438,17 +441,18 @@ func Add(db sql.Executor, atx *types.ActivationTx, blob types.AtxBlob) error { stmt.BindInt64(12, int64(atx.Validity())) if atx.PrevATXID != types.EmptyATXID { stmt.BindBytes(13, atx.PrevATXID.Bytes()) - } else { - stmt.BindNull(13) } stmt.BindInt64(14, int64(atx.Weight)) + if atx.MarriageATX != nil { + stmt.BindBytes(15, atx.MarriageATX.Bytes()) + } } _, err := db.Exec(` insert into atxs (id, epoch, effective_num_units, commitment_atx, nonce, pubkey, received, base_tick_height, tick_count, sequence, coinbase, - validity, prev_id, weight) - values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)`, enc, nil) + validity, prev_id, weight, marriage_atx) + values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)`, enc, nil) if err != nil { return fmt.Errorf("insert ATX ID %v: %w", atx.ID(), err) } @@ -539,6 +543,7 @@ type CheckpointAtx struct { ID types.ATXID Epoch types.EpochID CommitmentATX types.ATXID + MarriageATX *types.ATXID VRFNonce types.VRFPostIndex BaseTickHeight uint64 TickCount uint64 @@ -571,16 +576,21 @@ func LatestN(db sql.Executor, n int) ([]CheckpointAtx, error) { catx.Sequence = uint64(stmt.ColumnInt64(6)) stmt.ColumnBytes(7, catx.Coinbase[:]) catx.VRFNonce = types.VRFPostIndex(stmt.ColumnInt64(8)) - catx.Units = make(map[types.NodeID]uint32) + if stmt.ColumnType(9) != sqlite.SQLITE_NULL { + catx.MarriageATX = new(types.ATXID) + stmt.ColumnBytes(9, catx.MarriageATX[:]) + } rst = append(rst, catx) return true } rows, err := db.Exec(` - select id, epoch, effective_num_units, base_tick_height, tick_count, pubkey, sequence, coinbase, nonce + select + id, epoch, effective_num_units, base_tick_height, tick_count, pubkey, sequence, coinbase, nonce, marriage_atx from ( select row_number() over (partition by pubkey order by epoch desc) RowNum, - id, epoch, effective_num_units, base_tick_height, tick_count, pubkey, sequence, coinbase, nonce + id, epoch, effective_num_units, base_tick_height, tick_count, pubkey, sequence, coinbase, nonce, + marriage_atx from atxs ) where RowNum <= ?1 order by pubkey;`, enc, dec) @@ -616,12 +626,15 @@ func AddCheckpointed(db sql.Executor, catx *CheckpointAtx) error { stmt.BindInt64(8, int64(catx.Sequence)) stmt.BindBytes(9, catx.SmesherID.Bytes()) stmt.BindBytes(10, catx.Coinbase.Bytes()) + if catx.MarriageATX != nil { + stmt.BindBytes(11, catx.MarriageATX.Bytes()) + } } _, err := db.Exec(` insert into atxs (id, epoch, effective_num_units, commitment_atx, nonce, - base_tick_height, tick_count, sequence, pubkey, coinbase, received) - values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, 0)`, enc, nil) + base_tick_height, tick_count, sequence, pubkey, coinbase, marriage_atx, received) + values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, 0)`, enc, nil) if err != nil { return fmt.Errorf("insert checkpoint ATX %v: %w", catx.ID, err) } @@ -803,7 +816,7 @@ func IterateAtxsWithMalfeasance( func(s *sql.Statement) { s.BindInt64(1, int64(publish)) }, func(s *sql.Statement) bool { return decoder(func(atx *types.ActivationTx) bool { - return fn(atx, s.ColumnInt(14) != 0) + return fn(atx, s.ColumnInt(15) != 0) })(s) }, ) @@ -995,3 +1008,28 @@ func AtxWithPrevious(db sql.Executor, prev types.ATXID, id types.NodeID) (types. } return atxid, nil } + +// Find 2 distinct merged ATXs (having the same marriage ATX) in the same epoch. +func MergeConflict(db sql.Executor, marriage types.ATXID, publish types.EpochID) ([]types.ATXID, error) { + var ids []types.ATXID + rows, err := db.Exec(` + SELECT id FROM atxs WHERE marriage_atx = ?1 and epoch = ?2;`, + func(stmt *sql.Statement) { + stmt.BindBytes(1, marriage.Bytes()) + stmt.BindInt64(2, int64(publish)) + }, + func(stmt *sql.Statement) bool { + var id types.ATXID + stmt.ColumnBytes(0, id[:]) + ids = append(ids, id) + return len(ids) < 2 + }, + ) + if err != nil { + return nil, err + } + if rows != 2 { + return nil, sql.ErrNotFound + } + return ids, nil +} diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 06d778c639..1dd1914968 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1278,3 +1278,55 @@ func Test_FindDoublePublish(t *testing.T) { require.ElementsMatch(t, []types.ATXID{atx0.ID(), atx1.ID()}, atxIDs) }) } + +func Test_MergeConflict(t *testing.T) { + t.Parallel() + t.Run("no atxs", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + _, err := atxs.MergeConflict(db, types.RandomATXID(), 0) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("no conflict", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + marriage := types.RandomATXID() + + atx := types.ActivationTx{MarriageATX: &marriage} + atx.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(db, &atx, types.AtxBlob{})) + + _, err := atxs.MergeConflict(db, types.RandomATXID(), atx.PublishEpoch) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("finds conflict", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + marriage := types.RandomATXID() + + atx0 := types.ActivationTx{MarriageATX: &marriage} + atx0.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(db, &atx0, types.AtxBlob{})) + + atx1 := types.ActivationTx{MarriageATX: &marriage} + atx1.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(db, &atx1, types.AtxBlob{})) + + ids, err := atxs.MergeConflict(db, marriage, atx0.PublishEpoch) + require.NoError(t, err) + require.ElementsMatch(t, []types.ATXID{atx0.ID(), atx1.ID()}, ids) + + // filters by epoch + _, err = atxs.MergeConflict(db, types.RandomATXID(), 8) + require.ErrorIs(t, err, sql.ErrNotFound) + + // returns only 2 ATXs + atx2 := types.ActivationTx{MarriageATX: &marriage} + atx2.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(db, &atx2, types.AtxBlob{})) + + ids, err = atxs.MergeConflict(db, marriage, atx0.PublishEpoch) + require.NoError(t, err) + require.Len(t, ids, 2) + }) +} diff --git a/sql/migrations/state/0020_atx_merge.sql b/sql/migrations/state/0020_atx_merge.sql index 17bcda9c83..8dbff567c0 100644 --- a/sql/migrations/state/0020_atx_merge.sql +++ b/sql/migrations/state/0020_atx_merge.sql @@ -1,5 +1,6 @@ -- Changes required to handle merged ATXs +ALTER TABLE atxs ADD COLUMN marriage_atx CHAR(32); ALTER TABLE atxs ADD COLUMN weight INTEGER; UPDATE atxs SET weight = effective_num_units * tick_count;