From 430408c0ac2606a451128cca0437e0ff942a6dc2 Mon Sep 17 00:00:00 2001 From: k <30611210+countvonzero@users.noreply.github.com> Date: Thu, 7 Sep 2023 21:59:30 +0000 Subject: [PATCH] make sure ref ballot has epoch data (#4974) ## Motivation part of #4903 ## Changes make sure to return the ballot with valid epoch data for the ref ballot --- miner/oracle.go | 19 ++-------- miner/proposal_builder.go | 6 ++-- miner/proposal_builder_test.go | 5 ++- sql/ballots/ballots.go | 66 +++++++++++++++++++++++++--------- sql/ballots/ballots_test.go | 51 +++++++++++++------------- 5 files changed, 86 insertions(+), 61 deletions(-) diff --git a/miner/oracle.go b/miner/oracle.go index 9b12cb2656..79c528d4dd 100644 --- a/miner/oracle.go +++ b/miner/oracle.go @@ -204,26 +204,11 @@ func (o *Oracle) activeSet(targetEpoch types.EpochID) (uint64, uint64, types.ATX return ownWeight, totalWeight, ownAtx, atxids, nil } -func refBallot(db sql.Executor, epoch types.EpochID, nodeID types.NodeID) (*types.Ballot, error) { - ref, err := ballots.GetRefBallot(db, epoch, nodeID) - if errors.Is(err, sql.ErrNotFound) { - return nil, nil - } - if err != nil { - return nil, fmt.Errorf("miner get refballot: %w", err) - } - ballot, err := ballots.Get(db, ref) - if err != nil { - return nil, fmt.Errorf("miner get ballot: %w", err) - } - return ballot, nil -} - // calcEligibilityProofs calculates the eligibility proofs of proposals for the miner in the given epoch // and returns the proofs along with the epoch's active set. func (o *Oracle) calcEligibilityProofs(lid types.LayerID, epoch types.EpochID, beacon types.Beacon, nonce types.VRFPostIndex) (*EpochEligibility, error) { - ref, err := refBallot(o.cdb, epoch, o.vrfSigner.NodeID()) - if err != nil { + ref, err := ballots.RefBallot(o.cdb, epoch, o.vrfSigner.NodeID()) + if err != nil && !errors.Is(err, sql.ErrNotFound) { return nil, err } diff --git a/miner/proposal_builder.go b/miner/proposal_builder.go index c83ffe1e03..ceeb7d976d 100644 --- a/miner/proposal_builder.go +++ b/miner/proposal_builder.go @@ -240,7 +240,7 @@ func (pb *ProposalBuilder) createProposal( } epoch := layerID.GetEpoch() - refBallot, err := ballots.GetRefBallot(pb.cdb, epoch, pb.signer.NodeID()) + ref, err := ballots.FirstInEpoch(pb.cdb, epochEligibility.Atx, epoch) if err != nil { if !errors.Is(err, sql.ErrNotFound) { return nil, fmt.Errorf("get ref ballot: %w", err) @@ -261,9 +261,9 @@ func (pb *ProposalBuilder) createProposal( pb.logger.With().Debug("creating ballot with reference ballot (no active set)", log.Context(ctx), layerID, - log.Named("ref_ballot", refBallot), + log.Stringer("ref_ballot", ref.ID()), ) - ib.RefBallot = refBallot + ib.RefBallot = ref.ID() } p := &types.Proposal{ diff --git a/miner/proposal_builder_test.go b/miner/proposal_builder_test.go index 1b752466b6..8a9c30169e 100644 --- a/miner/proposal_builder_test.go +++ b/miner/proposal_builder_test.go @@ -411,7 +411,10 @@ func TestBuilder_HandleLayer_RefBallot(t *testing.T) { b := createBuilder(t) layerID := types.LayerID(layersPerEpoch * 3).Add(1) + atx := types.ATXID{1, 2, 3} refBallot := types.NewExistingBallot(types.BallotID{1}, types.EmptyEdSignature, b.ProposalBuilder.signer.NodeID(), layerID.Sub(1)) + refBallot.EpochData = &types.EpochData{} + refBallot.AtxID = atx require.NoError(t, ballots.Add(b.cdb, &refBallot)) beacon := types.RandomBeacon() sig, err := signing.NewEdSigner() @@ -423,7 +426,7 @@ func TestBuilder_HandleLayer_RefBallot(t *testing.T) { b.mBeacon.EXPECT().GetBeacon(gomock.Any()).Return(beacon, nil) b.mNonce.EXPECT().VRFNonce(gomock.Any(), gomock.Any()).Return(nonce, nil) ee := &EpochEligibility{ - Atx: types.RandomATXID(), + Atx: atx, ActiveSet: genActiveSet(t), Proofs: map[types.LayerID][]types.VotingEligibility{layerID: genProofs(t, 1)}, Slots: 4, diff --git a/sql/ballots/ballots.go b/sql/ballots/ballots.go index 28bdfd729c..827a29df95 100644 --- a/sql/ballots/ballots.go +++ b/sql/ballots/ballots.go @@ -185,29 +185,54 @@ func LayerBallotByNodeID(db sql.Executor, lid types.LayerID, nodeID types.NodeID return &ballot, err } -// GetRefBallot gets a ref ballot for a layer and a nodeID. -func GetRefBallot(db sql.Executor, epochID types.EpochID, nodeID types.NodeID) (ballotID types.BallotID, err error) { - firstLayer := epochID.FirstLayer() +// RefBallot gets a ref ballot for a layer and a nodeID. +func RefBallot(db sql.Executor, epoch types.EpochID, nodeID types.NodeID) (*types.Ballot, error) { + firstLayer := epoch.FirstLayer() lastLayer := firstLayer.Add(types.GetLayersPerEpoch()).Sub(1) - rows, err := db.Exec(` - select id from ballots + var ( + bid types.BallotID + ballot types.Ballot + rows, n int + err error + ) + dec := func(stmt *sql.Statement) bool { + stmt.ColumnBytes(0, bid[:]) + if n, err = codec.DecodeFrom(stmt.ColumnReader(1), &ballot); err != nil { + if err != io.EOF { + err = fmt.Errorf("ref ballot %s/%d: %w", nodeID.ShortString(), epoch, err) + return false + } + } else if n == 0 { + err = fmt.Errorf("ref ballot missing data %s/%d", nodeID.ShortString(), epoch) + return false + } + ballot.SetID(bid) + ballot.SmesherID = nodeID + if stmt.ColumnInt(2) > 0 { + ballot.SetMalicious() + } + // only ref ballot has valid EpochData + if ballot.EpochData != nil { + return false + } + return true + } + rows, err = db.Exec(` + select id, ballot, length(identities.proof) from ballots + left join identities using(pubkey) where layer between ?1 and ?2 and pubkey = ?3 - order by layer - limit 1;`, + order by layer asc;`, func(stmt *sql.Statement) { stmt.BindInt64(1, int64(firstLayer)) stmt.BindInt64(2, int64(lastLayer)) stmt.BindBytes(3, nodeID.Bytes()) - }, func(stmt *sql.Statement) bool { - stmt.ColumnBytes(0, ballotID[:]) - return true - }) + }, dec) if err != nil { - return types.BallotID{}, fmt.Errorf("ref ballot epoch %v: %w", epochID, err) + return nil, fmt.Errorf("ref ballot %s/%d: %w", nodeID.ShortString(), epoch, err) } else if rows == 0 { - return types.BallotID{}, fmt.Errorf("%w ref ballot epoch %s", sql.ErrNotFound, epochID) + return nil, fmt.Errorf("%w ref ballot %s/%d", sql.ErrNotFound, nodeID.ShortString(), epoch) } - return ballotID, nil + return &ballot, nil } // LatestLayer gets the highest layer with ballots. @@ -251,11 +276,20 @@ func FirstInEpoch(db sql.Executor, atx types.ATXID, epoch types.EpochID) (*types } ballot.SetID(bid) ballot.SmesherID = nodeID + if stmt.ColumnInt(3) > 0 { + ballot.SetMalicious() + } + // only ref ballot has valid EpochData + if ballot.EpochData != nil { + return false + } return true } rows, err = db.Exec(` - select id, pubkey, ballot from ballots where atx = ?1 and layer between ?2 and ?3 - order by layer asc, id asc limit 1;`, enc, dec) + select id, pubkey, ballot, length(identities.proof) from ballots + left join identities using(pubkey) + where atx = ?1 and layer between ?2 and ?3 + order by layer asc;`, enc, dec) if err != nil { return nil, fmt.Errorf("ballot by atx %s: %w", atx, err) } diff --git a/sql/ballots/ballots_test.go b/sql/ballots/ballots_test.go index 1f742aa141..88dd3b8b77 100644 --- a/sql/ballots/ballots_test.go +++ b/sql/ballots/ballots_test.go @@ -153,42 +153,36 @@ func TestLayerBallotBySmesher(t *testing.T) { require.Equal(t, ballots[1], *prev) } -func TestGetRefBallot(t *testing.T) { +func TestRefBallot(t *testing.T) { db := sql.InMemory() - lid2 := types.LayerID(2) - lid3 := types.LayerID(3) - lid4 := types.LayerID(4) - lid5 := types.LayerID(5) - lid6 := types.LayerID(6) - nodeID1 := types.RandomNodeID() - nodeID2 := types.RandomNodeID() - nodeID3 := types.RandomNodeID() - nodeID4 := types.RandomNodeID() ballots := []types.Ballot{ - types.NewExistingBallot(types.BallotID{1}, types.EmptyEdSignature, nodeID1, lid2), - types.NewExistingBallot(types.BallotID{2}, types.EmptyEdSignature, nodeID1, lid3), - types.NewExistingBallot(types.BallotID{3}, types.EmptyEdSignature, nodeID2, lid3), - types.NewExistingBallot(types.BallotID{4}, types.EmptyEdSignature, nodeID2, lid4), - types.NewExistingBallot(types.BallotID{5}, types.EmptyEdSignature, nodeID3, lid5), - types.NewExistingBallot(types.BallotID{6}, types.EmptyEdSignature, nodeID4, lid6), + types.NewExistingBallot(types.BallotID{1}, types.EmptyEdSignature, types.NodeID{1}, types.EpochID(1).FirstLayer()-1), + types.NewExistingBallot(types.BallotID{2}, types.EmptyEdSignature, types.NodeID{1}, types.EpochID(1).FirstLayer()), + types.NewExistingBallot(types.BallotID{3}, types.EmptyEdSignature, types.NodeID{2}, types.EpochID(1).FirstLayer()), + types.NewExistingBallot(types.BallotID{4}, types.EmptyEdSignature, types.NodeID{2}, types.EpochID(1).FirstLayer()+1), + types.NewExistingBallot(types.BallotID{5}, types.EmptyEdSignature, types.NodeID{3}, types.EpochID(1).FirstLayer()+2), + types.NewExistingBallot(types.BallotID{6}, types.EmptyEdSignature, types.NodeID{4}, types.EpochID(2).FirstLayer()), + } + for _, i := range []int{0, 1, 2, 4, 5} { + ballots[i].EpochData = &types.EpochData{Beacon: types.Beacon{1, 2}} } for _, ballot := range ballots { require.NoError(t, Add(db, &ballot)) } - count, err := GetRefBallot(db, 1, nodeID1) + got, err := RefBallot(db, 1, types.NodeID{1}) require.NoError(t, err) - require.Equal(t, types.BallotID{2}, count) + require.Equal(t, ballots[1], *got) - count, err = GetRefBallot(db, 1, nodeID2) + got, err = RefBallot(db, 1, types.NodeID{2}) require.NoError(t, err) - require.Equal(t, types.BallotID{3}, count) + require.Equal(t, ballots[2], *got) - count, err = GetRefBallot(db, 1, nodeID3) + got, err = RefBallot(db, 1, types.NodeID{3}) require.NoError(t, err) - require.Equal(t, types.BallotID{5}, count) + require.Equal(t, ballots[4], *got) - _, err = GetRefBallot(db, 1, nodeID4) + _, err = RefBallot(db, 1, types.NodeID{4}) require.ErrorIs(t, err, sql.ErrNotFound) } @@ -229,6 +223,7 @@ func TestFirstInEpoch(t *testing.T) { require.NoError(t, Add(db, &b1)) b2 := types.NewExistingBallot(types.BallotID{2}, types.EmptyEdSignature, sig.NodeID(), lid) b2.AtxID = atx.ID() + b2.EpochData = &types.EpochData{} require.NoError(t, Add(db, &b2)) b3 := types.NewExistingBallot(types.BallotID{3}, types.EmptyEdSignature, sig.NodeID(), lid.Add(1)) b3.AtxID = atx.ID() @@ -236,8 +231,16 @@ func TestFirstInEpoch(t *testing.T) { got, err = FirstInEpoch(db, atx.ID(), 2) require.NoError(t, err) + require.False(t, got.IsMalicious()) + require.Equal(t, got.AtxID, atx.ID()) + require.Equal(t, got.ID(), b2.ID()) + + require.NoError(t, identities.SetMalicious(db, sig.NodeID(), []byte("bad"), time.Now())) + got, err = FirstInEpoch(db, atx.ID(), 2) + require.NoError(t, err) + require.True(t, got.IsMalicious()) require.Equal(t, got.AtxID, atx.ID()) - require.Equal(t, got.ID(), b1.ID()) + require.Equal(t, got.ID(), b2.ID()) } func TestAllFirstInEpoch(t *testing.T) {