Skip to content

Commit

Permalink
make sure ref ballot has epoch data
Browse files Browse the repository at this point in the history
  • Loading branch information
countvonzero committed Sep 7, 2023
1 parent 35f57da commit 4955f27
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 61 deletions.
19 changes: 2 additions & 17 deletions miner/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,26 +204,11 @@ func (o *Oracle) activeSet(targetEpoch types.EpochID) (uint64, uint64, types.ATX
return ownWeight, totalWeight, ownAtx, atxids, nil
}

func refBallot(db sql.Executor, epoch types.EpochID, nodeID types.NodeID) (*types.Ballot, error) {
ref, err := ballots.GetRefBallot(db, epoch, nodeID)
if errors.Is(err, sql.ErrNotFound) {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("miner get refballot: %w", err)
}
ballot, err := ballots.Get(db, ref)
if err != nil {
return nil, fmt.Errorf("miner get ballot: %w", err)
}
return ballot, nil
}

// calcEligibilityProofs calculates the eligibility proofs of proposals for the miner in the given epoch
// and returns the proofs along with the epoch's active set.
func (o *Oracle) calcEligibilityProofs(lid types.LayerID, epoch types.EpochID, beacon types.Beacon, nonce types.VRFPostIndex) (*EpochEligibility, error) {
ref, err := refBallot(o.cdb, epoch, o.vrfSigner.NodeID())
if err != nil {
ref, err := ballots.RefBallot(o.cdb, epoch, o.vrfSigner.NodeID())
if err != nil && !errors.Is(err, sql.ErrNotFound) {
return nil, err
}

Expand Down
6 changes: 3 additions & 3 deletions miner/proposal_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (pb *ProposalBuilder) createProposal(
}

epoch := layerID.GetEpoch()
refBallot, err := ballots.GetRefBallot(pb.cdb, epoch, pb.signer.NodeID())
ref, err := ballots.FirstInEpoch(pb.cdb, epochEligibility.Atx, epoch)
if err != nil {
if !errors.Is(err, sql.ErrNotFound) {
return nil, fmt.Errorf("get ref ballot: %w", err)
Expand All @@ -261,9 +261,9 @@ func (pb *ProposalBuilder) createProposal(
pb.logger.With().Debug("creating ballot with reference ballot (no active set)",
log.Context(ctx),
layerID,
log.Named("ref_ballot", refBallot),
log.Stringer("ref_ballot", ref.ID()),
)
ib.RefBallot = refBallot
ib.RefBallot = ref.ID()
}

p := &types.Proposal{
Expand Down
5 changes: 4 additions & 1 deletion miner/proposal_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,10 @@ func TestBuilder_HandleLayer_RefBallot(t *testing.T) {
b := createBuilder(t)

layerID := types.LayerID(layersPerEpoch * 3).Add(1)
atx := types.ATXID{1, 2, 3}
refBallot := types.NewExistingBallot(types.BallotID{1}, types.EmptyEdSignature, b.ProposalBuilder.signer.NodeID(), layerID.Sub(1))
refBallot.EpochData = &types.EpochData{}
refBallot.AtxID = atx
require.NoError(t, ballots.Add(b.cdb, &refBallot))
beacon := types.RandomBeacon()
sig, err := signing.NewEdSigner()
Expand All @@ -423,7 +426,7 @@ func TestBuilder_HandleLayer_RefBallot(t *testing.T) {
b.mBeacon.EXPECT().GetBeacon(gomock.Any()).Return(beacon, nil)
b.mNonce.EXPECT().VRFNonce(gomock.Any(), gomock.Any()).Return(nonce, nil)
ee := &EpochEligibility{
Atx: types.RandomATXID(),
Atx: atx,
ActiveSet: genActiveSet(t),
Proofs: map[types.LayerID][]types.VotingEligibility{layerID: genProofs(t, 1)},
Slots: 4,
Expand Down
66 changes: 50 additions & 16 deletions sql/ballots/ballots.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,29 +185,54 @@ func LayerBallotByNodeID(db sql.Executor, lid types.LayerID, nodeID types.NodeID
return &ballot, err
}

// GetRefBallot gets a ref ballot for a layer and a nodeID.
func GetRefBallot(db sql.Executor, epochID types.EpochID, nodeID types.NodeID) (ballotID types.BallotID, err error) {
firstLayer := epochID.FirstLayer()
// RefBallot gets a ref ballot for a layer and a nodeID.
func RefBallot(db sql.Executor, epoch types.EpochID, nodeID types.NodeID) (*types.Ballot, error) {
firstLayer := epoch.FirstLayer()
lastLayer := firstLayer.Add(types.GetLayersPerEpoch()).Sub(1)
rows, err := db.Exec(`
select id from ballots
var (
bid types.BallotID
ballot types.Ballot
rows, n int
err error
)
dec := func(stmt *sql.Statement) bool {
stmt.ColumnBytes(0, bid[:])
if n, err = codec.DecodeFrom(stmt.ColumnReader(1), &ballot); err != nil {
if err != io.EOF {
err = fmt.Errorf("ref ballot %s/%d: %w", nodeID.ShortString(), epoch, err)
return false
}

Check warning on line 204 in sql/ballots/ballots.go

View check run for this annotation

Codecov / codecov/patch

sql/ballots/ballots.go#L201-L204

Added lines #L201 - L204 were not covered by tests
} else if n == 0 {
err = fmt.Errorf("ref ballot missing data %s/%d", nodeID.ShortString(), epoch)
return false
}

Check warning on line 208 in sql/ballots/ballots.go

View check run for this annotation

Codecov / codecov/patch

sql/ballots/ballots.go#L206-L208

Added lines #L206 - L208 were not covered by tests
ballot.SetID(bid)
ballot.SmesherID = nodeID
if stmt.ColumnInt(2) > 0 {
ballot.SetMalicious()
}

Check warning on line 213 in sql/ballots/ballots.go

View check run for this annotation

Codecov / codecov/patch

sql/ballots/ballots.go#L212-L213

Added lines #L212 - L213 were not covered by tests
// only ref ballot has valid EpochData
if ballot.EpochData != nil {
return false
}
return true

Check warning on line 218 in sql/ballots/ballots.go

View check run for this annotation

Codecov / codecov/patch

sql/ballots/ballots.go#L218

Added line #L218 was not covered by tests
}
rows, err = db.Exec(`
select id, ballot, length(identities.proof) from ballots
left join identities using(pubkey)
where layer between ?1 and ?2 and pubkey = ?3
order by layer
limit 1;`,
order by layer asc;`,
func(stmt *sql.Statement) {
stmt.BindInt64(1, int64(firstLayer))
stmt.BindInt64(2, int64(lastLayer))
stmt.BindBytes(3, nodeID.Bytes())
}, func(stmt *sql.Statement) bool {
stmt.ColumnBytes(0, ballotID[:])
return true
})
}, dec)
if err != nil {
return types.BallotID{}, fmt.Errorf("ref ballot epoch %v: %w", epochID, err)
return nil, fmt.Errorf("ref ballot %s/%d: %w", nodeID.ShortString(), epoch, err)

Check warning on line 231 in sql/ballots/ballots.go

View check run for this annotation

Codecov / codecov/patch

sql/ballots/ballots.go#L231

Added line #L231 was not covered by tests
} else if rows == 0 {
return types.BallotID{}, fmt.Errorf("%w ref ballot epoch %s", sql.ErrNotFound, epochID)
return nil, fmt.Errorf("%w ref ballot %s/%d", sql.ErrNotFound, nodeID.ShortString(), epoch)
}
return ballotID, nil
return &ballot, nil
}

// LatestLayer gets the highest layer with ballots.
Expand Down Expand Up @@ -251,11 +276,20 @@ func FirstInEpoch(db sql.Executor, atx types.ATXID, epoch types.EpochID) (*types
}
ballot.SetID(bid)
ballot.SmesherID = nodeID
if stmt.ColumnInt(3) > 0 {
ballot.SetMalicious()
}
// only ref ballot has valid EpochData
if ballot.EpochData != nil {
return false
}
return true
}
rows, err = db.Exec(`
select id, pubkey, ballot from ballots where atx = ?1 and layer between ?2 and ?3
order by layer asc, id asc limit 1;`, enc, dec)
select id, pubkey, ballot, length(identities.proof) from ballots
left join identities using(pubkey)
where atx = ?1 and layer between ?2 and ?3
order by layer asc;`, enc, dec)
if err != nil {
return nil, fmt.Errorf("ballot by atx %s: %w", atx, err)
}
Expand Down
51 changes: 27 additions & 24 deletions sql/ballots/ballots_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,42 +153,36 @@ func TestLayerBallotBySmesher(t *testing.T) {
require.Equal(t, ballots[1], *prev)
}

func TestGetRefBallot(t *testing.T) {
func TestRefBallot(t *testing.T) {
db := sql.InMemory()
lid2 := types.LayerID(2)
lid3 := types.LayerID(3)
lid4 := types.LayerID(4)
lid5 := types.LayerID(5)
lid6 := types.LayerID(6)
nodeID1 := types.RandomNodeID()
nodeID2 := types.RandomNodeID()
nodeID3 := types.RandomNodeID()
nodeID4 := types.RandomNodeID()
ballots := []types.Ballot{
types.NewExistingBallot(types.BallotID{1}, types.EmptyEdSignature, nodeID1, lid2),
types.NewExistingBallot(types.BallotID{2}, types.EmptyEdSignature, nodeID1, lid3),
types.NewExistingBallot(types.BallotID{3}, types.EmptyEdSignature, nodeID2, lid3),
types.NewExistingBallot(types.BallotID{4}, types.EmptyEdSignature, nodeID2, lid4),
types.NewExistingBallot(types.BallotID{5}, types.EmptyEdSignature, nodeID3, lid5),
types.NewExistingBallot(types.BallotID{6}, types.EmptyEdSignature, nodeID4, lid6),
types.NewExistingBallot(types.BallotID{1}, types.EmptyEdSignature, types.NodeID{1}, types.EpochID(1).FirstLayer()-1),
types.NewExistingBallot(types.BallotID{2}, types.EmptyEdSignature, types.NodeID{1}, types.EpochID(1).FirstLayer()),
types.NewExistingBallot(types.BallotID{3}, types.EmptyEdSignature, types.NodeID{2}, types.EpochID(1).FirstLayer()),
types.NewExistingBallot(types.BallotID{4}, types.EmptyEdSignature, types.NodeID{2}, types.EpochID(1).FirstLayer()+1),
types.NewExistingBallot(types.BallotID{5}, types.EmptyEdSignature, types.NodeID{3}, types.EpochID(1).FirstLayer()+2),
types.NewExistingBallot(types.BallotID{6}, types.EmptyEdSignature, types.NodeID{4}, types.EpochID(2).FirstLayer()),
}
for _, i := range []int{0, 1, 2, 4, 5} {
ballots[i].EpochData = &types.EpochData{Beacon: types.Beacon{1, 2}}
}
for _, ballot := range ballots {
require.NoError(t, Add(db, &ballot))
}

count, err := GetRefBallot(db, 1, nodeID1)
got, err := RefBallot(db, 1, types.NodeID{1})
require.NoError(t, err)
require.Equal(t, types.BallotID{2}, count)
require.Equal(t, ballots[1], *got)

count, err = GetRefBallot(db, 1, nodeID2)
got, err = RefBallot(db, 1, types.NodeID{2})
require.NoError(t, err)
require.Equal(t, types.BallotID{3}, count)
require.Equal(t, ballots[2], *got)

count, err = GetRefBallot(db, 1, nodeID3)
got, err = RefBallot(db, 1, types.NodeID{3})
require.NoError(t, err)
require.Equal(t, types.BallotID{5}, count)
require.Equal(t, ballots[4], *got)

_, err = GetRefBallot(db, 1, nodeID4)
_, err = RefBallot(db, 1, types.NodeID{4})
require.ErrorIs(t, err, sql.ErrNotFound)
}

Expand Down Expand Up @@ -229,15 +223,24 @@ func TestFirstInEpoch(t *testing.T) {
require.NoError(t, Add(db, &b1))
b2 := types.NewExistingBallot(types.BallotID{2}, types.EmptyEdSignature, sig.NodeID(), lid)
b2.AtxID = atx.ID()
b2.EpochData = &types.EpochData{}
require.NoError(t, Add(db, &b2))
b3 := types.NewExistingBallot(types.BallotID{3}, types.EmptyEdSignature, sig.NodeID(), lid.Add(1))
b3.AtxID = atx.ID()
require.NoError(t, Add(db, &b3))

got, err = FirstInEpoch(db, atx.ID(), 2)
require.NoError(t, err)
require.False(t, got.IsMalicious())
require.Equal(t, got.AtxID, atx.ID())
require.Equal(t, got.ID(), b2.ID())

require.NoError(t, identities.SetMalicious(db, sig.NodeID(), []byte("bad"), time.Now()))
got, err = FirstInEpoch(db, atx.ID(), 2)
require.NoError(t, err)
require.True(t, got.IsMalicious())
require.Equal(t, got.AtxID, atx.ID())
require.Equal(t, got.ID(), b1.ID())
require.Equal(t, got.ID(), b2.ID())
}

func TestAllFirstInEpoch(t *testing.T) {
Expand Down

0 comments on commit 4955f27

Please sign in to comment.