Skip to content

Commit

Permalink
Pass blob explictly to atxs.Add() (#6093)
Browse files Browse the repository at this point in the history
## Motivation
  • Loading branch information
poszu committed Jul 8, 2024
1 parent 7251d96 commit 5dbae68
Show file tree
Hide file tree
Showing 47 changed files with 342 additions and 359 deletions.
2 changes: 1 addition & 1 deletion activation/activation_multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func TestRegossip(t *testing.T) {
atx.PublishEpoch = layer.GetEpoch()
atx.Sign(sig)
vAtx := toAtx(t, atx)
require.NoError(t, atxs.Add(tab.db, vAtx))
require.NoError(t, atxs.Add(tab.db, vAtx, atx.Blob()))

if refAtx == nil {
refAtx = vAtx
Expand Down
32 changes: 16 additions & 16 deletions activation/activation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func publishAtxV1(
func(_ context.Context, _ string, got []byte) error {
return codec.Decode(got, &watx)
})
require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx)))
require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx), watx.Blob()))
tab.atxsdata.AddFromAtx(toAtx(tb, &watx), false)
return &watx
}
Expand Down Expand Up @@ -351,7 +351,7 @@ func TestBuilder_PublishActivationTx_HappyFlow(t *testing.T) {
currLayer := posEpoch.FirstLayer()
prevAtx := newInitialATXv1(t, tab.goldenATXID)
prevAtx.Sign(sig)
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx)))
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob()))
tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false)

// create and publish ATX
Expand Down Expand Up @@ -387,7 +387,7 @@ func TestBuilder_Loop_WaitsOnStaleChallenge(t *testing.T) {
currLayer := (postGenesisEpoch + 1).FirstLayer()
prevAtx := newInitialATXv1(t, tab.goldenATXID)
prevAtx.Sign(sig)
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx)))
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob()))
tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false)

tab.mclock.EXPECT().CurrentLayer().Return(currLayer).AnyTimes()
Expand Down Expand Up @@ -436,7 +436,7 @@ func TestBuilder_PublishActivationTx_FaultyNet(t *testing.T) {
currLayer := postGenesisEpoch.FirstLayer()
prevAtx := newInitialATXv1(t, tab.goldenATXID)
prevAtx.Sign(sig)
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx)))
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob()))
tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false)

publishEpoch := posEpoch + 1
Expand Down Expand Up @@ -511,7 +511,7 @@ func TestBuilder_PublishActivationTx_UsesExistingChallengeOnLatePublish(t *testi
prevAtx := newInitialATXv1(t, tab.goldenATXID)
prevAtx.Sign(sig)
vPrevAtx := toAtx(t, prevAtx)
require.NoError(t, atxs.Add(tab.db, vPrevAtx))
require.NoError(t, atxs.Add(tab.db, vPrevAtx, prevAtx.Blob()))
tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false)

publishEpoch := currLayer.GetEpoch()
Expand Down Expand Up @@ -588,7 +588,7 @@ func TestBuilder_PublishActivationTx_RebuildNIPostWhenTargetEpochPassed(t *testi
prevAtx := newInitialATXv1(t, tab.goldenATXID)
prevAtx.Sign(sig)
vPrevAtx := toAtx(t, prevAtx)
require.NoError(t, atxs.Add(tab.db, vPrevAtx))
require.NoError(t, atxs.Add(tab.db, vPrevAtx, prevAtx.Blob()))
tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false)

publishEpoch := posEpoch + 1
Expand Down Expand Up @@ -649,7 +649,7 @@ func TestBuilder_PublishActivationTx_RebuildNIPostWhenTargetEpochPassed(t *testi
currLayer = posEpoch.FirstLayer()
posAtx := newInitialATXv1(t, tab.goldenATXID, func(atx *wire.ActivationTxV1) { atx.PublishEpoch = posEpoch })
posAtx.Sign(sig)
require.NoError(t, atxs.Add(tab.db, toAtx(t, posAtx)))
require.NoError(t, atxs.Add(tab.db, toAtx(t, posAtx), posAtx.Blob()))
tab.atxsdata.AddFromAtx(toAtx(t, posAtx), false)
tab.mclock.EXPECT().CurrentLayer().DoAndReturn(func() types.LayerID { return currLayer }).AnyTimes()
tab.mnipost.EXPECT().ResetState(sig.NodeID()).Return(nil)
Expand Down Expand Up @@ -791,7 +791,7 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) {
posAtx.Sign(otherSigner)
vPosAtx := toAtx(t, posAtx)
vPosAtx.TickCount = 100
r.NoError(atxs.Add(tab.db, vPosAtx))
r.NoError(atxs.Add(tab.db, vPosAtx, posAtx.Blob()))
tab.atxsdata.AddFromAtx(vPosAtx, false)

nonce := types.VRFPostIndex(123)
Expand All @@ -800,7 +800,7 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) {
})
prevAtx.Sign(sig)
vPrevAtx := toAtx(t, prevAtx)
r.NoError(atxs.Add(tab.db, vPrevAtx))
r.NoError(atxs.Add(tab.db, vPrevAtx, prevAtx.Blob()))
tab.atxsdata.AddFromAtx(vPrevAtx, false)

// Act
Expand Down Expand Up @@ -884,7 +884,7 @@ func TestBuilder_PublishActivationTx_TargetsEpochBasedOnPosAtx(t *testing.T) {
posEpoch := postGenesisEpoch
posAtx := newInitialATXv1(t, tab.goldenATXID)
posAtx.Sign(otherSigner)
r.NoError(atxs.Add(tab.db, toAtx(t, posAtx)))
r.NoError(atxs.Add(tab.db, toAtx(t, posAtx), posAtx.Blob()))
tab.atxsdata.AddFromAtx(toAtx(t, posAtx), false)

// Act & Assert
Expand Down Expand Up @@ -977,7 +977,7 @@ func TestBuilder_PublishActivationTx_FailsWhenNIPostBuilderFails(t *testing.T) {
currLayer := posEpoch.FirstLayer()
prevAtx := newInitialATXv1(t, tab.goldenATXID)
prevAtx.Sign(sig)
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx)))
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob()))
tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false)

tab.mclock.EXPECT().CurrentLayer().Return(posEpoch.FirstLayer()).AnyTimes()
Expand Down Expand Up @@ -1035,7 +1035,7 @@ func TestBuilder_RetryPublishActivationTx(t *testing.T) {
sig := maps.Values(tab.signers)[0]
prevAtx := newInitialATXv1(t, tab.goldenATXID)
prevAtx.Sign(sig)
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx)))
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob()))
tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false)

currLayer := prevAtx.PublishEpoch.FirstLayer()
Expand Down Expand Up @@ -1345,7 +1345,7 @@ func TestGetPositioningAtxPicksAtxWithValidChain(t *testing.T) {
vInvalidAtx := toAtx(t, invalidAtx)
vInvalidAtx.TickCount = 100
require.NoError(t, err)
require.NoError(t, atxs.Add(tab.db, vInvalidAtx))
require.NoError(t, atxs.Add(tab.db, vInvalidAtx, invalidAtx.Blob()))
tab.atxsdata.AddFromAtx(vInvalidAtx, false)

// Valid chain with lower height
Expand All @@ -1355,7 +1355,7 @@ func TestGetPositioningAtxPicksAtxWithValidChain(t *testing.T) {
validAtx.NumUnits += 10
validAtx.Sign(sigValid)
vValidAtx := toAtx(t, validAtx)
require.NoError(t, atxs.Add(tab.db, vValidAtx))
require.NoError(t, atxs.Add(tab.db, vValidAtx, validAtx.Blob()))
tab.atxsdata.AddFromAtx(vValidAtx, false)

tab.mValidator.EXPECT().
Expand Down Expand Up @@ -1419,7 +1419,7 @@ func TestGetPositioningAtx(t *testing.T) {

atxInDb := &types.ActivationTx{TickCount: 10}
atxInDb.SetID(types.RandomATXID())
require.NoError(t, atxs.Add(tab.db, atxInDb))
require.NoError(t, atxs.Add(tab.db, atxInDb, types.AtxBlob{}))
tab.atxsdata.AddFromAtx(atxInDb, false)

prev := &types.ActivationTx{TickCount: 100}
Expand All @@ -1446,7 +1446,7 @@ func TestGetPositioningAtx(t *testing.T) {

atxInDb := &types.ActivationTx{TickCount: 100}
atxInDb.SetID(types.RandomATXID())
require.NoError(t, atxs.Add(tab.db, atxInDb))
require.NoError(t, atxs.Add(tab.db, atxInDb, types.AtxBlob{}))
tab.atxsdata.AddFromAtx(atxInDb, false)

prev := &types.ActivationTx{TickCount: 90}
Expand Down
2 changes: 1 addition & 1 deletion activation/builder_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) {

prevAtx := newInitialATXv1(t, tab.goldenATXID)
prevAtx.Sign(sig)
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx)))
require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob()))

posEpoch := prevAtx.PublishEpoch
layer := posEpoch.FirstLayer()
Expand Down
2 changes: 1 addition & 1 deletion activation/certifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func TestObtainingPost(t *testing.T) {

atx := newInitialATXv1(t, types.RandomATXID())
atx.SmesherID = id
require.NoError(t, atxs.Add(db, toAtx(t, atx)))
require.NoError(t, atxs.Add(db, toAtx(t, atx), atx.Blob()))

certifier := NewCertifierClient(db, localDb, zaptest.NewLogger(t))
got, err := certifier.obtainPost(context.Background(), id)
Expand Down
2 changes: 1 addition & 1 deletion activation/e2e/activation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func Test_BuilderWithMultipleClients(t *testing.T) {
require.NoError(t, err)
}
logger.Debug("persisting ATX", zap.Inline(atx))
require.NoError(t, atxs.Add(db, atx))
require.NoError(t, atxs.Add(db, atx, gotAtx.Blob()))
data.AddFromAtx(atx, false)

if atxsPublished.Add(1) == totalAtxs {
Expand Down
26 changes: 11 additions & 15 deletions activation/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,28 +239,24 @@ type opaqueAtx interface {
ID() types.ATXID
}

func (h *Handler) decodeATX(msg []byte) (opaqueAtx, error) {
func (h *Handler) decodeATX(msg []byte) (atx opaqueAtx, err error) {
version, err := h.determineVersion(msg)
if err != nil {
return nil, fmt.Errorf("determining ATX version: %w", err)
}

switch *version {
case types.AtxV1:
var atx wire.ActivationTxV1
if err := codec.Decode(msg, &atx); err != nil {
return nil, fmt.Errorf("%w: %w", errMalformedData, err)
}
return &atx, nil
atx, err = wire.DecodeAtxV1(msg)
case types.AtxV2:
var atx wire.ActivationTxV2
if err := codec.Decode(msg, &atx); err != nil {
return nil, fmt.Errorf("%w: %w", errMalformedData, err)
}
return &atx, nil
atx, err = wire.DecodeAtxV2(msg)
default:
return nil, fmt.Errorf("unsupported ATX version: %v", *version)
}

return nil, fmt.Errorf("unsupported ATX version: %v", *version)
if err != nil {
return nil, fmt.Errorf("%w: %w", errMalformedData, err)
}
return atx, nil
}

func (h *Handler) handleAtx(
Expand Down Expand Up @@ -316,9 +312,9 @@ func (h *Handler) handleAtx(

switch atx := opaqueAtx.(type) {
case *wire.ActivationTxV1:
proof, err = h.v1.processATX(ctx, peer, atx, msg, receivedTime)
proof, err = h.v1.processATX(ctx, peer, atx, receivedTime)
case *wire.ActivationTxV2:
proof, err = h.v2.processATX(ctx, peer, atx, msg, receivedTime)
proof, err = h.v2.processATX(ctx, peer, atx, receivedTime)
default:
panic("unreachable")
}
Expand Down
6 changes: 3 additions & 3 deletions activation/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ func TestHandler_DecodeATX(t *testing.T) {
atxHdlr := newTestHandler(t, types.RandomATXID())

atx := newInitialATXv1(t, atxHdlr.goldenATXID)
decoded, err := atxHdlr.decodeATX(codec.MustEncode(atx))
decoded, err := atxHdlr.decodeATX(atx.Blob().Blob)
require.NoError(t, err)
require.Equal(t, atx, decoded)
})
Expand All @@ -880,7 +880,7 @@ func TestHandler_DecodeATX(t *testing.T) {

atx := newInitialATXv2(t, atxHdlr.goldenATXID)
atx.PublishEpoch = 10
decoded, err := atxHdlr.decodeATX(codec.MustEncode(atx))
decoded, err := atxHdlr.decodeATX(atx.Blob().Blob)
require.NoError(t, err)
require.Equal(t, atx, decoded)
})
Expand All @@ -891,7 +891,7 @@ func TestHandler_DecodeATX(t *testing.T) {

atx := newInitialATXv2(t, atxHdlr.goldenATXID)
atx.PublishEpoch = 9
_, err := atxHdlr.decodeATX(codec.MustEncode(atx))
_, err := atxHdlr.decodeATX(atx.Blob().Blob)
require.ErrorIs(t, err, errMalformedData)
})
}
5 changes: 2 additions & 3 deletions activation/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ func (h *HandlerV1) storeAtx(
return fmt.Errorf("check malicious: %w", err)
}

err = atxs.Add(tx, atx)
err = atxs.Add(tx, atx, watx.Blob())
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
Expand Down Expand Up @@ -587,7 +587,6 @@ func (h *HandlerV1) processATX(
ctx context.Context,
peer p2p.Peer,
watx *wire.ActivationTxV1,
blob []byte,
received time.Time,
) (*mwire.MalfeasanceProof, error) {
if !h.edVerifier.Verify(signing.ATX, watx.SmesherID, watx.SignedBytes(), watx.Signature) {
Expand Down Expand Up @@ -644,7 +643,7 @@ func (h *HandlerV1) processATX(
baseTickHeight = posAtx.TickHeight()
}

atx := wire.ActivationTxFromWireV1(watx, blob...)
atx := wire.ActivationTxFromWireV1(watx)
if h.nipostValidator.IsVerifyingFullPost() {
atx.SetValidity(types.Valid)
}
Expand Down
15 changes: 6 additions & 9 deletions activation/handler_v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

"github.com/spacemeshos/go-spacemesh/activation/wire"
"github.com/spacemeshos/go-spacemesh/atxsdata"
"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/datastore"
mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire"
Expand Down Expand Up @@ -67,7 +66,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) {
prevAtx.NumUnits = 100
prevAtx.Sign(sig)
atxHdlr.expectAtxV1(prevAtx, sig.NodeID())
_, err := atxHdlr.processATX(context.Background(), "", prevAtx, codec.MustEncode(prevAtx), time.Now())
_, err := atxHdlr.processATX(context.Background(), "", prevAtx, time.Now())
require.NoError(t, err)

otherSig, err := signing.NewEdSigner()
Expand All @@ -76,7 +75,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) {
posAtx := newInitialATXv1(t, goldenATXID)
posAtx.Sign(otherSig)
atxHdlr.expectAtxV1(posAtx, otherSig.NodeID())
_, err = atxHdlr.processATX(context.Background(), "", posAtx, codec.MustEncode(posAtx), time.Now())
_, err = atxHdlr.processATX(context.Background(), "", posAtx, time.Now())
require.NoError(t, err)
return atxHdlr, prevAtx, posAtx
}
Expand Down Expand Up @@ -488,14 +487,14 @@ func TestHandler_ContextuallyValidateAtx(t *testing.T) {
atx0 := newInitialATXv1(t, goldenATXID)
atx0.Sign(sig)
atxHdlr.expectAtxV1(atx0, sig.NodeID())
_, err := atxHdlr.processATX(context.Background(), "", atx0, codec.MustEncode(atx0), time.Now())
_, err := atxHdlr.processATX(context.Background(), "", atx0, time.Now())
require.NoError(t, err)

atx1 := newChainedActivationTxV1(t, atx0, goldenATXID)
atx1.Sign(sig)
atxHdlr.expectAtxV1(atx1, sig.NodeID())
atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), gomock.Any(), gomock.Any())
_, err = atxHdlr.processATX(context.Background(), "", atx1, codec.MustEncode(atx1), time.Now())
_, err = atxHdlr.processATX(context.Background(), "", atx1, time.Now())
require.NoError(t, err)

atxInvalidPrevious := newChainedActivationTxV1(t, atx0, goldenATXID)
Expand All @@ -515,13 +514,13 @@ func TestHandler_ContextuallyValidateAtx(t *testing.T) {
atx0 := newInitialATXv1(t, goldenATXID)
atx0.Sign(otherSig)
atxHdlr.expectAtxV1(atx0, otherSig.NodeID())
_, err = atxHdlr.processATX(context.Background(), "", atx0, codec.MustEncode(atx0), time.Now())
_, err = atxHdlr.processATX(context.Background(), "", atx0, time.Now())
require.NoError(t, err)

atx1 := newInitialATXv1(t, goldenATXID)
atx1.Sign(sig)
atxHdlr.expectAtxV1(atx1, sig.NodeID())
_, err = atxHdlr.processATX(context.Background(), "", atx1, codec.MustEncode(atx1), time.Now())
_, err = atxHdlr.processATX(context.Background(), "", atx1, time.Now())
require.NoError(t, err)

atxInvalidPrevious := newChainedActivationTxV1(t, atx0, goldenATXID)
Expand Down Expand Up @@ -555,7 +554,6 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
atxFromDb, err := atxs.Get(atxHdlr.cdb, atx.ID())
require.NoError(t, err)
atx.SetReceived(time.Unix(0, atx.Received().UnixNano()))
atx.AtxBlob = types.AtxBlob{}
require.Equal(t, atx, atxFromDb)
})

Expand Down Expand Up @@ -605,7 +603,6 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
atxFromDb, err := atxs.Get(atxHdlr.cdb, atx.ID())
require.NoError(t, err)
atx.SetReceived(time.Unix(0, atx.Received().UnixNano()))
atx.AtxBlob = types.AtxBlob{}
require.Equal(t, atx, atxFromDb)
})

Expand Down
4 changes: 1 addition & 3 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ func (h *HandlerV2) processATX(
ctx context.Context,
peer p2p.Peer,
watx *wire.ActivationTxV2,
blob []byte,
received time.Time,
) (*mwire.MalfeasanceProof, error) {
exists, err := atxs.Has(h.cdb, watx.ID())
Expand Down Expand Up @@ -129,7 +128,6 @@ func (h *HandlerV2) processATX(
Weight: parts.weight,
VRFNonce: types.VRFPostIndex(watx.VRFNonce),
SmesherID: watx.SmesherID,
AtxBlob: types.AtxBlob{Blob: blob, Version: types.AtxV2},
}

if watx.Initial == nil {
Expand Down Expand Up @@ -746,7 +744,7 @@ func (h *HandlerV2) storeAtx(
}
}

err = atxs.Add(tx, atx)
err = atxs.Add(tx, atx, watx.Blob())
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
Expand Down
Loading

0 comments on commit 5dbae68

Please sign in to comment.