From d21910b80c7ac39be486eee2f12300de6b32b84c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Tue, 13 Aug 2024 11:54:47 +0000 Subject: [PATCH] detect invalid previous ATX for V2 ATXs (#6189) ## Motivation Add a check to detect whether the previous ATX of a V2 ATX is the correct one. --- activation/handler_v2.go | 61 +++++++++++++++++++++---- activation/handler_v2_test.go | 44 ++++++++++++++++++ activation/wire/malfeasance.go | 9 ++-- sql/atxs/atxs.go | 65 ++++++++++---------------- sql/atxs/atxs_test.go | 83 ++++++++++++++++++++++++++++------ 5 files changed, 193 insertions(+), 69 deletions(-) diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 6755b9f7e0..99342c43e8 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -699,6 +699,14 @@ func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activat return nil } + malicious, err = h.checkPrevAtx(ctx, tx, atx) + if err != nil { + return fmt.Errorf("checking previous ATX: %w", err) + } + if malicious { + return nil + } + // TODO(mafa): contextual validation: // 1. check double-publish = ID contributed post to two ATXs in the same epoch // 2. check previous ATX @@ -762,29 +770,66 @@ func (h *HandlerV2) checkDoublePost(ctx context.Context, tx *sql.Tx, atx *activa return false, nil } -func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *activationTx) (bool, error) { - if watx.MarriageATX == nil { +func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { + if atx.MarriageATX == nil { return false, nil } - ids, err := atxs.MergeConflict(tx, *watx.MarriageATX, watx.PublishEpoch) + ids, err := atxs.MergeConflict(tx, *atx.MarriageATX, atx.PublishEpoch) switch { case errors.Is(err, sql.ErrNotFound): return false, nil case err != nil: return false, fmt.Errorf("searching for ATXs with the same marriage ATX: %w", err) } - otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != watx.ID() }) + otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != atx.ID() }) other := ids[otherIndex] h.logger.Debug("second merged ATX for single marriage - creating malfeasance proof", - zap.Stringer("marriage_atx", *watx.MarriageATX), - zap.Stringer("atx", watx.ID()), + zap.Stringer("marriage_atx", *atx.MarriageATX), + zap.Stringer("atx", atx.ID()), zap.Stringer("other_atx", other), - zap.Stringer("smesher_id", watx.SmesherID), + zap.Stringer("smesher_id", atx.SmesherID), ) var proof wire.Proof - return true, h.malPublisher.Publish(ctx, watx.SmesherID, proof) + return true, h.malPublisher.Publish(ctx, atx.SmesherID, proof) +} + +func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { + for id, data := range atx.ids { + expectedPrevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch) + if err != nil && !errors.Is(err, sql.ErrNotFound) { + return false, fmt.Errorf("get last atx by node id: %w", err) + } + if expectedPrevID == data.previous { + continue + } + + h.logger.Debug("atx references a wrong previous ATX", + log.ZShortStringer("smesherID", id), + log.ZShortStringer("actual", data.previous), + log.ZShortStringer("expected", expectedPrevID), + ) + + atx1, atx2, err := atxs.PrevATXCollision(tx, data.previous, id) + switch { + case errors.Is(err, sql.ErrNotFound): + continue + case err != nil: + return false, fmt.Errorf("checking for previous ATX collision: %w", err) + } + + h.logger.Debug("creating a malfeasance proof for invalid previous ATX", + log.ZShortStringer("smesherID", id), + log.ZShortStringer("atx1", atx1), + log.ZShortStringer("atx2", atx2), + ) + + // TODO(mafa): finish proof + var proof wire.Proof + return true, h.malPublisher.Publish(ctx, id, proof) + } + return false, nil } // Store an ATX in the DB. diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 46f6e76834..b596a2ed2c 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -1894,6 +1894,50 @@ func Test_CalculatingUnits(t *testing.T) { }) } +func TestContextual_PreviousATX(t *testing.T) { + golden := types.RandomATXID() + atxHndlr := newV2TestHandler(t, golden) + var ( + signers []*signing.EdSigner + eqSet []types.NodeID + ) + for range 3 { + sig, err := signing.NewEdSigner() + require.NoError(t, err) + signers = append(signers, sig) + eqSet = append(eqSet, sig.NodeID()) + } + + mATX, otherAtxs := marryIDs(t, atxHndlr, signers, golden) + + // signer 1 creates a solo ATX + soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID()) + soloAtx.Sign(signers[1]) + atxHndlr.expectAtxV2(soloAtx) + err := atxHndlr.processATX(context.Background(), "", soloAtx, time.Now()) + require.NoError(t, err) + + // create a MergedATX for all IDs + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + NumUnits: soloAtx.TotalNumUnits(), + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + // Pass a wrong previous ATX for signer 1. It's already been used for soloATX + // (which should be used for the previous ATX for signer 1). + merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID()) + matxID := mATX.ID() + merged.MarriageATX = &matxID + merged.Sign(signers[0]) + + atxHndlr.expectMergedAtxV2(merged, eqSet, []uint64{100}) + atxHndlr.mMalPublish.EXPECT().Publish(gomock.Any(), signers[1].NodeID(), gomock.Any()) + err = atxHndlr.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) +} + func Test_CalculatingWeight(t *testing.T) { t.Parallel() t.Run("total weight must not overflow uint64", func(t *testing.T) { diff --git a/activation/wire/malfeasance.go b/activation/wire/malfeasance.go index 019a52d6cc..c857dd075b 100644 --- a/activation/wire/malfeasance.go +++ b/activation/wire/malfeasance.go @@ -33,10 +33,11 @@ const ( LegacyInvalidPost ProofType = 0x01 LegacyInvalidPrevATX ProofType = 0x02 - DoublePublish ProofType = 0x10 - DoubleMarry ProofType = 0x11 - DoubleMerge ProofType = 0x12 - InvalidPost ProofType = 0x13 + DoublePublish ProofType = 0x10 + DoubleMarry ProofType = 0x11 + DoubleMerge ProofType = 0x12 + InvalidPost ProofType = 0x13 + InvalidPrevious ProofType = 0x14 ) // ProofVersion is an identifier for the version of the proof that is encoded in the ATXProof. diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 248dd99a0d..9d749a035b 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -245,7 +245,8 @@ func GetLastIDByNodeID(db sql.Executor, nodeID types.NodeID) (id types.ATXID, er } // PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch. -// It returns the newest ATX ID that was published before the given public epoch. +// It returns the newest ATX ID containing PoST of the given node ID +// that was published before the given public epoch. func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) { enc := func(stmt *sql.Statement) { stmt.BindBytes(1, nodeID.Bytes()) @@ -257,10 +258,10 @@ func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID } if rows, err := db.Exec(` - select id from atxs - where pubkey = ?1 and epoch < ?2 - order by epoch desc - limit 1;`, enc, dec); err != nil { + SELECT posts.atxid FROM posts JOIN atxs ON posts.atxid = atxs.id + WHERE posts.pubkey = ?1 AND atxs.epoch < ?2 + ORDER BY atxs.epoch DESC + LIMIT 1;`, enc, dec); err != nil { return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err) } else if rows == 0 { return types.EmptyATXID, fmt.Errorf("exec nodeID %s, epoch %d: %w", nodeID, pubEpoch, sql.ErrNotFound) @@ -861,46 +862,26 @@ func IterateAtxIdsWithMalfeasance( return err } -type PrevATXCollision struct { - NodeID1 types.NodeID - ATX1 types.ATXID - - NodeID2 types.NodeID - ATX2 types.ATXID -} - -func PrevATXCollisions(db sql.Executor) ([]PrevATXCollision, error) { - var result []PrevATXCollision - +func PrevATXCollision(db sql.Executor, prev types.ATXID, id types.NodeID) (types.ATXID, types.ATXID, error) { + var atxs []types.ATXID + enc := func(stmt *sql.Statement) { + stmt.BindBytes(1, prev[:]) + stmt.BindBytes(2, id[:]) + } dec := func(stmt *sql.Statement) bool { - var nodeID1, nodeID2 types.NodeID - stmt.ColumnBytes(0, nodeID1[:]) - stmt.ColumnBytes(1, nodeID2[:]) - - var id1, id2 types.ATXID - stmt.ColumnBytes(2, id1[:]) - stmt.ColumnBytes(3, id2[:]) - - result = append(result, PrevATXCollision{ - NodeID1: nodeID1, - ATX1: id1, - - NodeID2: nodeID2, - ATX2: id2, - }) - return true + var id types.ATXID + stmt.ColumnBytes(0, id[:]) + atxs = append(atxs, id) + return len(atxs) < 2 } - // we are joining the table with itself to find ATXs with the same prevATX - // the WHERE clause ensures that we only get the pairs once - if _, err := db.Exec(` - SELECT p1.pubkey, p2.pubkey, p1.atxid, p2.atxid - FROM posts p1 - INNER JOIN posts p2 ON p1.prev_atxid = p2.prev_atxid - WHERE p1.atxid < p2.atxid;`, nil, dec); err != nil { - return nil, fmt.Errorf("error getting ATXs with same prevATX: %w", err) + _, err := db.Exec("SELECT atxid FROM posts WHERE prev_atxid = ?1 AND pubkey = ?2;", enc, dec) + if err != nil { + return types.EmptyATXID, types.EmptyATXID, fmt.Errorf("error getting ATXs with same prevATX: %w", err) } - - return result, nil + if len(atxs) != 2 { + return types.EmptyATXID, types.EmptyATXID, sql.ErrNotFound + } + return atxs[0], atxs[1], nil } func Units(db sql.Executor, atxID types.ATXID, nodeID types.NodeID) (uint32, error) { diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index eb507fe969..aaebce0755 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1023,7 +1023,7 @@ func TestLatest(t *testing.T) { } } -func Test_PrevATXCollisions(t *testing.T) { +func Test_PrevATXCollision(t *testing.T) { db := sql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -1048,29 +1048,29 @@ func Test_PrevATXCollisions(t *testing.T) { require.NoError(t, err) require.Equal(t, atx2, got2) - // add 10 valid ATXs by 10 other smeshers + // add 10 valid ATXs by 10 other smeshers, using the same previous but no collision + var otherIds []types.NodeID for i := 2; i < 6; i++ { otherSig, err := signing.NewEdSigner() require.NoError(t, err) + otherIds = append(otherIds, otherSig.NodeID()) - atx, blob := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i))) - require.NoError(t, atxs.Add(db, atx, blob)) - - atx2, blob2 := newAtx(t, otherSig, - withPublishEpoch(types.EpochID(i+1)), - ) + atx2, blob2 := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i+1))) require.NoError(t, atxs.Add(db, atx2, blob2)) - require.NoError(t, atxs.SetPost(db, atx2.ID(), atx.ID(), 0, sig.NodeID(), 10)) + require.NoError(t, atxs.SetPost(db, atx2.ID(), prevATXID, 0, atx2.SmesherID, 10)) } - // get the collisions - got, err := atxs.PrevATXCollisions(db) + collision1, collision2, err := atxs.PrevATXCollision(db, prevATXID, sig.NodeID()) require.NoError(t, err) - require.Len(t, got, 1) + require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{collision1, collision2}) - require.Equal(t, sig.NodeID(), got[0].NodeID1) - require.Equal(t, sig.NodeID(), got[0].NodeID2) - require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{got[0].ATX1, got[0].ATX2}) + _, _, err = atxs.PrevATXCollision(db, types.RandomATXID(), sig.NodeID()) + require.ErrorIs(t, err, sql.ErrNotFound) + + for _, id := range append(otherIds, types.RandomNodeID()) { + _, _, err := atxs.PrevATXCollision(db, prevATXID, id) + require.ErrorIs(t, err, sql.ErrNotFound) + } } func TestCoinbase(t *testing.T) { @@ -1362,3 +1362,56 @@ func Test_Previous(t *testing.T) { require.Equal(t, previousAtxs, got) }) } + +func TestPrevIDByNodeID(t *testing.T) { + t.Run("no previous ATXs", func(t *testing.T) { + db := sql.InMemory() + _, err := atxs.PrevIDByNodeID(db, types.RandomNodeID(), 0) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("filters by epoch", func(t *testing.T) { + db := sql.InMemory() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, sig.NodeID(), 4)) + + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2)) + require.NoError(t, atxs.Add(db, atx2, blob2)) + require.NoError(t, atxs.SetPost(db, atx2.ID(), types.EmptyATXID, 0, sig.NodeID(), 4)) + + _, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 1) + require.ErrorIs(t, err, sql.ErrNotFound) + + prevID, err := atxs.PrevIDByNodeID(db, sig.NodeID(), 2) + require.NoError(t, err) + require.Equal(t, atx1.ID(), prevID) + + prevID, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 3) + require.NoError(t, err) + require.Equal(t, atx2.ID(), prevID) + }) + t.Run("the previous is merged and ID is not the signer", func(t *testing.T) { + db := sql.InMemory() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + id := types.RandomNodeID() + + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, sig.NodeID(), 4)) + require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, id, 8)) + require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, types.RandomNodeID(), 12)) + + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2)) + require.NoError(t, atxs.Add(db, atx2, blob2)) + require.NoError(t, atxs.SetPost(db, atx2.ID(), atx1.ID(), 0, sig.NodeID(), 4)) + require.NoError(t, atxs.SetPost(db, atx2.ID(), atx1.ID(), 0, types.RandomNodeID(), 12)) + + prevID, err := atxs.PrevIDByNodeID(db, id, 3) + require.NoError(t, err) + require.Equal(t, atx1.ID(), prevID) + }) +}