Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Support multiple previous ATXs #6024

Closed
wants to merge 10 commits into from
1 change: 1 addition & 0 deletions activation/activation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ func publishAtxV1(
return codec.Decode(got, &watx)
})
require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx), watx.Blob()))
require.NoError(tb, atxs.SetPost(tab.db, watx.ID(), watx.PrevATXID, 0, watx.SmesherID, watx.NumUnits))
tab.atxsdata.AddFromAtx(toAtx(tb, &watx), false)
return &watx
}
Expand Down
5 changes: 4 additions & 1 deletion activation/e2e/atx_merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,9 @@ func Test_MarryAndMerge(t *testing.T) {
require.Equal(t, units[i], atxFromDb.NumUnits)
require.Equal(t, signer.NodeID(), atxFromDb.SmesherID)
require.Equal(t, publish, atxFromDb.PublishEpoch)
require.Equal(t, mergedATX2.ID(), atxFromDb.PrevATXID)
prev, err := atxs.Previous(db, atxFromDb.ID())
require.NoError(t, err)
require.Len(t, prev, 1)
require.Equal(t, mergedATX2.ID(), prev[0])
}
}
2 changes: 1 addition & 1 deletion activation/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ func (h *HandlerV1) storeAtx(
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
err = atxs.SetUnits(tx, atx.ID(), atx.SmesherID, watx.NumUnits)
err = atxs.SetPost(tx, atx.ID(), watx.PrevATXID, 0, atx.SmesherID, watx.NumUnits)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("set atx units: %w", err)
}
Expand Down
107 changes: 49 additions & 58 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,25 @@ func (h *HandlerV2) processATX(
return fmt.Errorf("%w: validating marriages: %w", pubsub.ErrValidationReject, err)
}

parts, err := h.syntacticallyValidateDeps(ctx, watx)
atxData, err := h.syntacticallyValidateDeps(ctx, watx)
if err != nil {
return fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err)
}
atxData.marriages = marrying

atx := &types.ActivationTx{
PublishEpoch: watx.PublishEpoch,
MarriageATX: watx.MarriageATX,
Coinbase: watx.Coinbase,
BaseTickHeight: baseTickHeight,
NumUnits: parts.effectiveUnits,
TickCount: parts.ticks,
Weight: parts.weight,
NumUnits: atxData.effectiveUnits,
TickCount: atxData.ticks,
Weight: atxData.weight,
VRFNonce: types.VRFPostIndex(watx.VRFNonce),
SmesherID: watx.SmesherID,
}

if watx.Initial == nil {
// FIXME: update to keep many previous ATXs to support merged ATXs
atx.PrevATXID = watx.PreviousATXs[0]
} else {
if watx.Initial != nil {
atx.CommitmentATX = &watx.Initial.CommitmentATX
}

Expand All @@ -141,12 +139,12 @@ func (h *HandlerV2) processATX(
atx.SetID(watx.ID())
atx.SetReceived(received)

if err := h.storeAtx(ctx, atx, watx, marrying, parts.units); err != nil {
if err := h.storeAtx(ctx, atx, atxData); err != nil {
return fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err)
}

events.ReportNewActivation(atx)
h.logger.Info("new atx", log.ZContext(ctx), zap.Inline(atx))
h.logger.Debug("new atx", log.ZContext(ctx), zap.Inline(atx))
return err
}

Expand Down Expand Up @@ -434,11 +432,19 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e
return identities.EquivocationSetByMarriageATX(h.cdb, *atx.MarriageATX)
}

type atxParts struct {
type idData struct {
previous types.ATXID
previousIndex int
units uint32
}

type activationTx struct {
*wire.ActivationTxV2
ticks uint64
weight uint64
effectiveUnits uint32
units map[types.NodeID]uint32
ids map[types.NodeID]idData
marriages []marriage
}

type nipostSize struct {
Expand Down Expand Up @@ -496,9 +502,10 @@ func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error
func (h *HandlerV2) syntacticallyValidateDeps(
ctx context.Context,
atx *wire.ActivationTxV2,
) (*atxParts, error) {
parts := atxParts{
units: make(map[types.NodeID]uint32),
) (*activationTx, error) {
result := activationTx{
ActivationTxV2: atx,
ids: make(map[types.NodeID]idData),
}
if atx.Initial != nil {
if err := h.validateCommitmentAtx(h.goldenATXID, atx.Initial.CommitmentATX, atx.PublishEpoch); err != nil {
Expand Down Expand Up @@ -586,7 +593,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
nipostSizes[i].ticks = leaves / h.tickSize
}

parts.effectiveUnits, parts.weight, err = nipostSizes.sumUp()
result.effectiveUnits, result.weight, err = nipostSizes.sumUp()
if err != nil {
return nil, err
}
Expand All @@ -597,6 +604,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
for _, post := range niposts.Posts {
id := equivocationSet[post.MarriageIndex]
var commitment types.ATXID
var previous types.ATXID
if atx.Initial != nil {
commitment = atx.Initial.CommitmentATX
} else {
Expand All @@ -608,6 +616,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if id == atx.SmesherID {
smesherCommitment = &commitment
}
previous = previousAtxs[post.PrevATXIndex].ID()
}

err := h.nipostValidator.PostV2(
Expand Down Expand Up @@ -635,7 +644,11 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if err != nil {
return nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err)
}
parts.units[id] = post.NumUnits
result.ids[id] = idData{
previous: previous,
previousIndex: int(post.PrevATXIndex),
units: post.NumUnits,
}
}
}

Expand All @@ -649,42 +662,36 @@ func (h *HandlerV2) syntacticallyValidateDeps(
}
}

parts.ticks = nipostSizes.minTicks()
return &parts, nil
result.ticks = nipostSizes.minTicks()
return &result, nil
}

func (h *HandlerV2) checkMalicious(
ctx context.Context,
tx *sql.Tx,
watx *wire.ActivationTxV2,
marrying []marriage,
ids []types.NodeID,
) error {
malicious, err := identities.IsMalicious(tx, watx.SmesherID)
func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activationTx) error {
malicious, err := identities.IsMalicious(tx, atx.SmesherID)
if err != nil {
return fmt.Errorf("checking if node is malicious: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoubleMarry(ctx, tx, watx, marrying)
malicious, err = h.checkDoubleMarry(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double marry: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoublePost(ctx, tx, watx, ids)
malicious, err = h.checkDoublePost(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double post: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoubleMerge(ctx, tx, watx)
malicious, err = h.checkDoubleMerge(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double merge: %w", err)
}
Expand All @@ -700,13 +707,8 @@ func (h *HandlerV2) checkMalicious(
return nil
}

func (h *HandlerV2) checkDoubleMarry(
ctx context.Context,
tx *sql.Tx,
atx *wire.ActivationTxV2,
marrying []marriage,
) (bool, error) {
for _, m := range marrying {
func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
for _, m := range atx.marriages {
mATX, err := identities.MarriageATX(tx, m.id)
if err != nil {
return false, fmt.Errorf("checking if ID is married: %w", err)
Expand All @@ -725,7 +727,7 @@ func (h *HandlerV2) checkDoubleMarry(
var otherAtx wire.ActivationTxV2
codec.MustDecode(blob.Bytes, &otherAtx)

proof, err := wire.NewDoubleMarryProof(tx, atx, &otherAtx, m.id)
proof, err := wire.NewDoubleMarryProof(tx, atx.ActivationTxV2, &otherAtx, m.id)
if err != nil {
return true, fmt.Errorf("creating double marry proof: %w", err)
}
Expand All @@ -735,13 +737,8 @@ func (h *HandlerV2) checkDoubleMarry(
return false, nil
}

func (h *HandlerV2) checkDoublePost(
ctx context.Context,
tx *sql.Tx,
atx *wire.ActivationTxV2,
ids []types.NodeID,
) (bool, error) {
for _, id := range ids {
func (h *HandlerV2) checkDoublePost(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
for id := range atx.ids {
atxids, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch)
switch {
case errors.Is(err, sql.ErrNotFound):
Expand All @@ -765,7 +762,7 @@ func (h *HandlerV2) checkDoublePost(
return false, nil
}

func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2) (bool, error) {
func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *activationTx) (bool, error) {
if watx.MarriageATX == nil {
return false, nil
}
Expand All @@ -791,20 +788,14 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire
}

// Store an ATX in the DB.
func (h *HandlerV2) storeAtx(
ctx context.Context,
atx *types.ActivationTx,
watx *wire.ActivationTxV2,
marrying []marriage,
units map[types.NodeID]uint32,
) error {
func (h *HandlerV2) storeAtx(ctx context.Context, atx *types.ActivationTx, watx *activationTx) error {
if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error {
if len(marrying) != 0 {
if len(watx.marriages) != 0 {
marriageData := identities.MarriageData{
ATX: atx.ID(),
Target: atx.SmesherID,
}
for i, m := range marrying {
for i, m := range watx.marriages {
marriageData.Signature = m.signature
marriageData.Index = i
if err := identities.SetMarriage(tx, m.id, &marriageData); err != nil {
Expand All @@ -817,8 +808,8 @@ func (h *HandlerV2) storeAtx(
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
for id, units := range units {
err = atxs.SetUnits(tx, atx.ID(), id, units)
for id, post := range watx.ids {
err = atxs.SetPost(tx, atx.ID(), post.previous, post.previousIndex, id, post.units)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("setting atx units for ID %s: %w", id, err)
}
Expand All @@ -837,7 +828,7 @@ func (h *HandlerV2) storeAtx(
// TODO(mafa): don't store own ATX if it would mark the node as malicious
// this probably needs to be done by validating and storing own ATXs eagerly and skipping validation in
// the gossip handler (not sync!)
err := h.checkMalicious(ctx, tx, watx, marrying, maps.Keys(units))
err := h.checkMalicious(ctx, tx, watx)
if err != nil {
return fmt.Errorf("check malicious: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions activation/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1435,7 +1435,7 @@ func Test_ValidatePreviousATX(t *testing.T) {
t.Parallel()
prev := &types.ActivationTx{}
prev.SetID(types.RandomATXID())
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), types.RandomNodeID(), 13))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, types.RandomNodeID(), 13))

_, err := atxHandler.validatePreviousAtx(types.RandomNodeID(), &wire.SubPostV2{}, []*types.ActivationTx{prev})
require.Error(t, err)
Expand All @@ -1446,8 +1446,8 @@ func Test_ValidatePreviousATX(t *testing.T) {
other := types.RandomNodeID()
prev := &types.ActivationTx{}
prev.SetID(types.RandomATXID())
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), id, 7))
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), other, 13))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, id, 7))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, other, 13))

units, err := atxHandler.validatePreviousAtx(id, &wire.SubPostV2{NumUnits: 100}, []*types.ActivationTx{prev})
require.NoError(t, err)
Expand All @@ -1467,7 +1467,7 @@ func Test_ValidatePreviousATX(t *testing.T) {
other := types.RandomNodeID()
prev := &types.ActivationTx{}
prev.SetID(types.RandomATXID())
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), other, 13))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, other, 13))

_, err := atxHandler.validatePreviousAtx(id, &wire.SubPostV2{NumUnits: 100}, []*types.ActivationTx{prev})
require.Error(t, err)
Expand Down
Loading
Loading