Skip to content

Commit

Permalink
detect invalid previous ATX for V2 ATXs (#6189)
Browse files Browse the repository at this point in the history
## Motivation

Add a check to detect whether the previous ATX of a V2 ATX is the correct one.
  • Loading branch information
poszu committed Aug 13, 2024
1 parent f135361 commit d21910b
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 69 deletions.
61 changes: 53 additions & 8 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,14 @@ func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activat
return nil
}

malicious, err = h.checkPrevAtx(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking previous ATX: %w", err)
}
if malicious {
return nil
}

// TODO(mafa): contextual validation:
// 1. check double-publish = ID contributed post to two ATXs in the same epoch
// 2. check previous ATX
Expand Down Expand Up @@ -762,29 +770,66 @@ func (h *HandlerV2) checkDoublePost(ctx context.Context, tx *sql.Tx, atx *activa
return false, nil
}

func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *activationTx) (bool, error) {
if watx.MarriageATX == nil {
func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
if atx.MarriageATX == nil {
return false, nil
}
ids, err := atxs.MergeConflict(tx, *watx.MarriageATX, watx.PublishEpoch)
ids, err := atxs.MergeConflict(tx, *atx.MarriageATX, atx.PublishEpoch)
switch {
case errors.Is(err, sql.ErrNotFound):
return false, nil
case err != nil:
return false, fmt.Errorf("searching for ATXs with the same marriage ATX: %w", err)
}
otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != watx.ID() })
otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != atx.ID() })
other := ids[otherIndex]

h.logger.Debug("second merged ATX for single marriage - creating malfeasance proof",
zap.Stringer("marriage_atx", *watx.MarriageATX),
zap.Stringer("atx", watx.ID()),
zap.Stringer("marriage_atx", *atx.MarriageATX),
zap.Stringer("atx", atx.ID()),
zap.Stringer("other_atx", other),
zap.Stringer("smesher_id", watx.SmesherID),
zap.Stringer("smesher_id", atx.SmesherID),
)

var proof wire.Proof
return true, h.malPublisher.Publish(ctx, watx.SmesherID, proof)
return true, h.malPublisher.Publish(ctx, atx.SmesherID, proof)
}

func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
for id, data := range atx.ids {
expectedPrevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
return false, fmt.Errorf("get last atx by node id: %w", err)
}
if expectedPrevID == data.previous {
continue
}

h.logger.Debug("atx references a wrong previous ATX",
log.ZShortStringer("smesherID", id),
log.ZShortStringer("actual", data.previous),
log.ZShortStringer("expected", expectedPrevID),
)

atx1, atx2, err := atxs.PrevATXCollision(tx, data.previous, id)
switch {
case errors.Is(err, sql.ErrNotFound):
continue
case err != nil:
return false, fmt.Errorf("checking for previous ATX collision: %w", err)
}

h.logger.Debug("creating a malfeasance proof for invalid previous ATX",
log.ZShortStringer("smesherID", id),
log.ZShortStringer("atx1", atx1),
log.ZShortStringer("atx2", atx2),
)

// TODO(mafa): finish proof
var proof wire.Proof
return true, h.malPublisher.Publish(ctx, id, proof)
}
return false, nil
}

// Store an ATX in the DB.
Expand Down
44 changes: 44 additions & 0 deletions activation/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,50 @@ func Test_CalculatingUnits(t *testing.T) {
})
}

func TestContextual_PreviousATX(t *testing.T) {
golden := types.RandomATXID()
atxHndlr := newV2TestHandler(t, golden)
var (
signers []*signing.EdSigner
eqSet []types.NodeID
)
for range 3 {
sig, err := signing.NewEdSigner()
require.NoError(t, err)
signers = append(signers, sig)
eqSet = append(eqSet, sig.NodeID())
}

mATX, otherAtxs := marryIDs(t, atxHndlr, signers, golden)

// signer 1 creates a solo ATX
soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID())
soloAtx.Sign(signers[1])
atxHndlr.expectAtxV2(soloAtx)
err := atxHndlr.processATX(context.Background(), "", soloAtx, time.Now())
require.NoError(t, err)

// create a MergedATX for all IDs
merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID())
post := wire.SubPostV2{
MarriageIndex: 1,
PrevATXIndex: 1,
NumUnits: soloAtx.TotalNumUnits(),
}
merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post)
// Pass a wrong previous ATX for signer 1. It's already been used for soloATX
// (which should be used for the previous ATX for signer 1).
merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID())
matxID := mATX.ID()
merged.MarriageATX = &matxID
merged.Sign(signers[0])

atxHndlr.expectMergedAtxV2(merged, eqSet, []uint64{100})
atxHndlr.mMalPublish.EXPECT().Publish(gomock.Any(), signers[1].NodeID(), gomock.Any())
err = atxHndlr.processATX(context.Background(), "", merged, time.Now())
require.NoError(t, err)
}

func Test_CalculatingWeight(t *testing.T) {
t.Parallel()
t.Run("total weight must not overflow uint64", func(t *testing.T) {
Expand Down
9 changes: 5 additions & 4 deletions activation/wire/malfeasance.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ const (
LegacyInvalidPost ProofType = 0x01
LegacyInvalidPrevATX ProofType = 0x02

DoublePublish ProofType = 0x10
DoubleMarry ProofType = 0x11
DoubleMerge ProofType = 0x12
InvalidPost ProofType = 0x13
DoublePublish ProofType = 0x10
DoubleMarry ProofType = 0x11
DoubleMerge ProofType = 0x12
InvalidPost ProofType = 0x13
InvalidPrevious ProofType = 0x14
)

// ProofVersion is an identifier for the version of the proof that is encoded in the ATXProof.
Expand Down
65 changes: 23 additions & 42 deletions sql/atxs/atxs.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ func GetLastIDByNodeID(db sql.Executor, nodeID types.NodeID) (id types.ATXID, er
}

// PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch.
// It returns the newest ATX ID that was published before the given public epoch.
// It returns the newest ATX ID containing PoST of the given node ID
// that was published before the given public epoch.
func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) {
enc := func(stmt *sql.Statement) {
stmt.BindBytes(1, nodeID.Bytes())
Expand All @@ -257,10 +258,10 @@ func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID
}

if rows, err := db.Exec(`
select id from atxs
where pubkey = ?1 and epoch < ?2
order by epoch desc
limit 1;`, enc, dec); err != nil {
SELECT posts.atxid FROM posts JOIN atxs ON posts.atxid = atxs.id
WHERE posts.pubkey = ?1 AND atxs.epoch < ?2
ORDER BY atxs.epoch DESC
LIMIT 1;`, enc, dec); err != nil {
return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err)
} else if rows == 0 {
return types.EmptyATXID, fmt.Errorf("exec nodeID %s, epoch %d: %w", nodeID, pubEpoch, sql.ErrNotFound)
Expand Down Expand Up @@ -861,46 +862,26 @@ func IterateAtxIdsWithMalfeasance(
return err
}

type PrevATXCollision struct {
NodeID1 types.NodeID
ATX1 types.ATXID

NodeID2 types.NodeID
ATX2 types.ATXID
}

func PrevATXCollisions(db sql.Executor) ([]PrevATXCollision, error) {
var result []PrevATXCollision

func PrevATXCollision(db sql.Executor, prev types.ATXID, id types.NodeID) (types.ATXID, types.ATXID, error) {
var atxs []types.ATXID
enc := func(stmt *sql.Statement) {
stmt.BindBytes(1, prev[:])
stmt.BindBytes(2, id[:])
}
dec := func(stmt *sql.Statement) bool {
var nodeID1, nodeID2 types.NodeID
stmt.ColumnBytes(0, nodeID1[:])
stmt.ColumnBytes(1, nodeID2[:])

var id1, id2 types.ATXID
stmt.ColumnBytes(2, id1[:])
stmt.ColumnBytes(3, id2[:])

result = append(result, PrevATXCollision{
NodeID1: nodeID1,
ATX1: id1,

NodeID2: nodeID2,
ATX2: id2,
})
return true
var id types.ATXID
stmt.ColumnBytes(0, id[:])
atxs = append(atxs, id)
return len(atxs) < 2
}
// we are joining the table with itself to find ATXs with the same prevATX
// the WHERE clause ensures that we only get the pairs once
if _, err := db.Exec(`
SELECT p1.pubkey, p2.pubkey, p1.atxid, p2.atxid
FROM posts p1
INNER JOIN posts p2 ON p1.prev_atxid = p2.prev_atxid
WHERE p1.atxid < p2.atxid;`, nil, dec); err != nil {
return nil, fmt.Errorf("error getting ATXs with same prevATX: %w", err)
_, err := db.Exec("SELECT atxid FROM posts WHERE prev_atxid = ?1 AND pubkey = ?2;", enc, dec)
if err != nil {
return types.EmptyATXID, types.EmptyATXID, fmt.Errorf("error getting ATXs with same prevATX: %w", err)
}

return result, nil
if len(atxs) != 2 {
return types.EmptyATXID, types.EmptyATXID, sql.ErrNotFound
}
return atxs[0], atxs[1], nil
}

func Units(db sql.Executor, atxID types.ATXID, nodeID types.NodeID) (uint32, error) {
Expand Down
83 changes: 68 additions & 15 deletions sql/atxs/atxs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ func TestLatest(t *testing.T) {
}
}

func Test_PrevATXCollisions(t *testing.T) {
func Test_PrevATXCollision(t *testing.T) {
db := sql.InMemory()
sig, err := signing.NewEdSigner()
require.NoError(t, err)
Expand All @@ -1048,29 +1048,29 @@ func Test_PrevATXCollisions(t *testing.T) {
require.NoError(t, err)
require.Equal(t, atx2, got2)

// add 10 valid ATXs by 10 other smeshers
// add 10 valid ATXs by 10 other smeshers, using the same previous but no collision
var otherIds []types.NodeID
for i := 2; i < 6; i++ {
otherSig, err := signing.NewEdSigner()
require.NoError(t, err)
otherIds = append(otherIds, otherSig.NodeID())

atx, blob := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i)))
require.NoError(t, atxs.Add(db, atx, blob))

atx2, blob2 := newAtx(t, otherSig,
withPublishEpoch(types.EpochID(i+1)),
)
atx2, blob2 := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i+1)))
require.NoError(t, atxs.Add(db, atx2, blob2))
require.NoError(t, atxs.SetPost(db, atx2.ID(), atx.ID(), 0, sig.NodeID(), 10))
require.NoError(t, atxs.SetPost(db, atx2.ID(), prevATXID, 0, atx2.SmesherID, 10))
}

// get the collisions
got, err := atxs.PrevATXCollisions(db)
collision1, collision2, err := atxs.PrevATXCollision(db, prevATXID, sig.NodeID())
require.NoError(t, err)
require.Len(t, got, 1)
require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{collision1, collision2})

require.Equal(t, sig.NodeID(), got[0].NodeID1)
require.Equal(t, sig.NodeID(), got[0].NodeID2)
require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{got[0].ATX1, got[0].ATX2})
_, _, err = atxs.PrevATXCollision(db, types.RandomATXID(), sig.NodeID())
require.ErrorIs(t, err, sql.ErrNotFound)

for _, id := range append(otherIds, types.RandomNodeID()) {
_, _, err := atxs.PrevATXCollision(db, prevATXID, id)
require.ErrorIs(t, err, sql.ErrNotFound)
}
}

func TestCoinbase(t *testing.T) {
Expand Down Expand Up @@ -1362,3 +1362,56 @@ func Test_Previous(t *testing.T) {
require.Equal(t, previousAtxs, got)
})
}

func TestPrevIDByNodeID(t *testing.T) {
t.Run("no previous ATXs", func(t *testing.T) {
db := sql.InMemory()
_, err := atxs.PrevIDByNodeID(db, types.RandomNodeID(), 0)
require.ErrorIs(t, err, sql.ErrNotFound)
})
t.Run("filters by epoch", func(t *testing.T) {
db := sql.InMemory()
sig, err := signing.NewEdSigner()
require.NoError(t, err)

atx1, blob1 := newAtx(t, sig, withPublishEpoch(1))
require.NoError(t, atxs.Add(db, atx1, blob1))
require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, sig.NodeID(), 4))

atx2, blob2 := newAtx(t, sig, withPublishEpoch(2))
require.NoError(t, atxs.Add(db, atx2, blob2))
require.NoError(t, atxs.SetPost(db, atx2.ID(), types.EmptyATXID, 0, sig.NodeID(), 4))

_, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 1)
require.ErrorIs(t, err, sql.ErrNotFound)

prevID, err := atxs.PrevIDByNodeID(db, sig.NodeID(), 2)
require.NoError(t, err)
require.Equal(t, atx1.ID(), prevID)

prevID, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 3)
require.NoError(t, err)
require.Equal(t, atx2.ID(), prevID)
})
t.Run("the previous is merged and ID is not the signer", func(t *testing.T) {
db := sql.InMemory()
sig, err := signing.NewEdSigner()
require.NoError(t, err)
id := types.RandomNodeID()

atx1, blob1 := newAtx(t, sig, withPublishEpoch(1))
require.NoError(t, atxs.Add(db, atx1, blob1))
require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, sig.NodeID(), 4))
require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, id, 8))
require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, types.RandomNodeID(), 12))

atx2, blob2 := newAtx(t, sig, withPublishEpoch(2))
require.NoError(t, atxs.Add(db, atx2, blob2))
require.NoError(t, atxs.SetPost(db, atx2.ID(), atx1.ID(), 0, sig.NodeID(), 4))
require.NoError(t, atxs.SetPost(db, atx2.ID(), atx1.ID(), 0, types.RandomNodeID(), 12))

prevID, err := atxs.PrevIDByNodeID(db, id, 3)
require.NoError(t, err)
require.Equal(t, atx1.ID(), prevID)
})
}

0 comments on commit d21910b

Please sign in to comment.