diff --git a/activation/activation_multi_test.go b/activation/activation_multi_test.go index 830da153f1..ce795e53b5 100644 --- a/activation/activation_multi_test.go +++ b/activation/activation_multi_test.go @@ -197,7 +197,7 @@ func TestRegossip(t *testing.T) { atx.PublishEpoch = layer.GetEpoch() atx.Sign(sig) vAtx := toAtx(t, atx) - require.NoError(t, atxs.Add(tab.db, vAtx)) + require.NoError(t, atxs.Add(tab.db, vAtx, atx.Blob())) if refAtx == nil { refAtx = vAtx diff --git a/activation/activation_test.go b/activation/activation_test.go index 299882eda1..a3c696cb24 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -136,7 +136,7 @@ func publishAtxV1( func(_ context.Context, _ string, got []byte) error { return codec.Decode(got, &watx) }) - require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx))) + require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx), watx.Blob())) tab.atxsdata.AddFromAtx(toAtx(tb, &watx), false) return &watx } @@ -351,7 +351,7 @@ func TestBuilder_PublishActivationTx_HappyFlow(t *testing.T) { currLayer := posEpoch.FirstLayer() prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) - require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx))) + require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) // create and publish ATX @@ -387,7 +387,7 @@ func TestBuilder_Loop_WaitsOnStaleChallenge(t *testing.T) { currLayer := (postGenesisEpoch + 1).FirstLayer() prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) - require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx))) + require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) tab.mclock.EXPECT().CurrentLayer().Return(currLayer).AnyTimes() @@ -436,7 +436,7 @@ func TestBuilder_PublishActivationTx_FaultyNet(t *testing.T) { currLayer := postGenesisEpoch.FirstLayer() prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) - require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx))) + require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) publishEpoch := posEpoch + 1 @@ -511,7 +511,7 @@ func TestBuilder_PublishActivationTx_UsesExistingChallengeOnLatePublish(t *testi prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) vPrevAtx := toAtx(t, prevAtx) - require.NoError(t, atxs.Add(tab.db, vPrevAtx)) + require.NoError(t, atxs.Add(tab.db, vPrevAtx, prevAtx.Blob())) tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) publishEpoch := currLayer.GetEpoch() @@ -588,7 +588,7 @@ func TestBuilder_PublishActivationTx_RebuildNIPostWhenTargetEpochPassed(t *testi prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) vPrevAtx := toAtx(t, prevAtx) - require.NoError(t, atxs.Add(tab.db, vPrevAtx)) + require.NoError(t, atxs.Add(tab.db, vPrevAtx, prevAtx.Blob())) tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) publishEpoch := posEpoch + 1 @@ -649,7 +649,7 @@ func TestBuilder_PublishActivationTx_RebuildNIPostWhenTargetEpochPassed(t *testi currLayer = posEpoch.FirstLayer() posAtx := newInitialATXv1(t, tab.goldenATXID, func(atx *wire.ActivationTxV1) { atx.PublishEpoch = posEpoch }) posAtx.Sign(sig) - require.NoError(t, atxs.Add(tab.db, toAtx(t, posAtx))) + require.NoError(t, atxs.Add(tab.db, toAtx(t, posAtx), posAtx.Blob())) tab.atxsdata.AddFromAtx(toAtx(t, posAtx), false) tab.mclock.EXPECT().CurrentLayer().DoAndReturn(func() types.LayerID { return currLayer }).AnyTimes() tab.mnipost.EXPECT().ResetState(sig.NodeID()).Return(nil) @@ -791,7 +791,7 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) { posAtx.Sign(otherSigner) vPosAtx := toAtx(t, posAtx) vPosAtx.TickCount = 100 - r.NoError(atxs.Add(tab.db, vPosAtx)) + r.NoError(atxs.Add(tab.db, vPosAtx, posAtx.Blob())) tab.atxsdata.AddFromAtx(vPosAtx, false) nonce := types.VRFPostIndex(123) @@ -800,7 +800,7 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) { }) prevAtx.Sign(sig) vPrevAtx := toAtx(t, prevAtx) - r.NoError(atxs.Add(tab.db, vPrevAtx)) + r.NoError(atxs.Add(tab.db, vPrevAtx, prevAtx.Blob())) tab.atxsdata.AddFromAtx(vPrevAtx, false) // Act @@ -884,7 +884,7 @@ func TestBuilder_PublishActivationTx_TargetsEpochBasedOnPosAtx(t *testing.T) { posEpoch := postGenesisEpoch posAtx := newInitialATXv1(t, tab.goldenATXID) posAtx.Sign(otherSigner) - r.NoError(atxs.Add(tab.db, toAtx(t, posAtx))) + r.NoError(atxs.Add(tab.db, toAtx(t, posAtx), posAtx.Blob())) tab.atxsdata.AddFromAtx(toAtx(t, posAtx), false) // Act & Assert @@ -977,7 +977,7 @@ func TestBuilder_PublishActivationTx_FailsWhenNIPostBuilderFails(t *testing.T) { currLayer := posEpoch.FirstLayer() prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) - require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx))) + require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) tab.mclock.EXPECT().CurrentLayer().Return(posEpoch.FirstLayer()).AnyTimes() @@ -1035,7 +1035,7 @@ func TestBuilder_RetryPublishActivationTx(t *testing.T) { sig := maps.Values(tab.signers)[0] prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) - require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx))) + require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) currLayer := prevAtx.PublishEpoch.FirstLayer() @@ -1345,7 +1345,7 @@ func TestGetPositioningAtxPicksAtxWithValidChain(t *testing.T) { vInvalidAtx := toAtx(t, invalidAtx) vInvalidAtx.TickCount = 100 require.NoError(t, err) - require.NoError(t, atxs.Add(tab.db, vInvalidAtx)) + require.NoError(t, atxs.Add(tab.db, vInvalidAtx, invalidAtx.Blob())) tab.atxsdata.AddFromAtx(vInvalidAtx, false) // Valid chain with lower height @@ -1355,7 +1355,7 @@ func TestGetPositioningAtxPicksAtxWithValidChain(t *testing.T) { validAtx.NumUnits += 10 validAtx.Sign(sigValid) vValidAtx := toAtx(t, validAtx) - require.NoError(t, atxs.Add(tab.db, vValidAtx)) + require.NoError(t, atxs.Add(tab.db, vValidAtx, validAtx.Blob())) tab.atxsdata.AddFromAtx(vValidAtx, false) tab.mValidator.EXPECT(). @@ -1419,7 +1419,7 @@ func TestGetPositioningAtx(t *testing.T) { atxInDb := &types.ActivationTx{TickCount: 10} atxInDb.SetID(types.RandomATXID()) - require.NoError(t, atxs.Add(tab.db, atxInDb)) + require.NoError(t, atxs.Add(tab.db, atxInDb, types.AtxBlob{})) tab.atxsdata.AddFromAtx(atxInDb, false) prev := &types.ActivationTx{TickCount: 100} @@ -1446,7 +1446,7 @@ func TestGetPositioningAtx(t *testing.T) { atxInDb := &types.ActivationTx{TickCount: 100} atxInDb.SetID(types.RandomATXID()) - require.NoError(t, atxs.Add(tab.db, atxInDb)) + require.NoError(t, atxs.Add(tab.db, atxInDb, types.AtxBlob{})) tab.atxsdata.AddFromAtx(atxInDb, false) prev := &types.ActivationTx{TickCount: 90} diff --git a/activation/builder_v2_test.go b/activation/builder_v2_test.go index 06e87899bb..209570b7c4 100644 --- a/activation/builder_v2_test.go +++ b/activation/builder_v2_test.go @@ -79,7 +79,7 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) { prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) - require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx))) + require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) posEpoch := prevAtx.PublishEpoch layer := posEpoch.FirstLayer() diff --git a/activation/certifier_test.go b/activation/certifier_test.go index cc329737da..fa0db876c0 100644 --- a/activation/certifier_test.go +++ b/activation/certifier_test.go @@ -147,7 +147,7 @@ func TestObtainingPost(t *testing.T) { atx := newInitialATXv1(t, types.RandomATXID()) atx.SmesherID = id - require.NoError(t, atxs.Add(db, toAtx(t, atx))) + require.NoError(t, atxs.Add(db, toAtx(t, atx), atx.Blob())) certifier := NewCertifierClient(db, localDb, zaptest.NewLogger(t)) got, err := certifier.obtainPost(context.Background(), id) diff --git a/activation/e2e/activation_test.go b/activation/e2e/activation_test.go index c58cca1f47..bf7557e9f7 100644 --- a/activation/e2e/activation_test.go +++ b/activation/e2e/activation_test.go @@ -176,7 +176,7 @@ func Test_BuilderWithMultipleClients(t *testing.T) { require.NoError(t, err) } logger.Debug("persisting ATX", zap.Inline(atx)) - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, gotAtx.Blob())) data.AddFromAtx(atx, false) if atxsPublished.Add(1) == totalAtxs { diff --git a/activation/handler.go b/activation/handler.go index 71f11ff72c..c828b45391 100644 --- a/activation/handler.go +++ b/activation/handler.go @@ -239,7 +239,7 @@ type opaqueAtx interface { ID() types.ATXID } -func (h *Handler) decodeATX(msg []byte) (opaqueAtx, error) { +func (h *Handler) decodeATX(msg []byte) (atx opaqueAtx, err error) { version, err := h.determineVersion(msg) if err != nil { return nil, fmt.Errorf("determining ATX version: %w", err) @@ -247,20 +247,16 @@ func (h *Handler) decodeATX(msg []byte) (opaqueAtx, error) { switch *version { case types.AtxV1: - var atx wire.ActivationTxV1 - if err := codec.Decode(msg, &atx); err != nil { - return nil, fmt.Errorf("%w: %w", errMalformedData, err) - } - return &atx, nil + atx, err = wire.DecodeAtxV1(msg) case types.AtxV2: - var atx wire.ActivationTxV2 - if err := codec.Decode(msg, &atx); err != nil { - return nil, fmt.Errorf("%w: %w", errMalformedData, err) - } - return &atx, nil + atx, err = wire.DecodeAtxV2(msg) + default: + return nil, fmt.Errorf("unsupported ATX version: %v", *version) } - - return nil, fmt.Errorf("unsupported ATX version: %v", *version) + if err != nil { + return nil, fmt.Errorf("%w: %w", errMalformedData, err) + } + return atx, nil } func (h *Handler) handleAtx( @@ -316,9 +312,9 @@ func (h *Handler) handleAtx( switch atx := opaqueAtx.(type) { case *wire.ActivationTxV1: - proof, err = h.v1.processATX(ctx, peer, atx, msg, receivedTime) + proof, err = h.v1.processATX(ctx, peer, atx, receivedTime) case *wire.ActivationTxV2: - proof, err = h.v2.processATX(ctx, peer, atx, msg, receivedTime) + proof, err = h.v2.processATX(ctx, peer, atx, receivedTime) default: panic("unreachable") } diff --git a/activation/handler_test.go b/activation/handler_test.go index 30c0681dc2..9b71a9f81f 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -869,7 +869,7 @@ func TestHandler_DecodeATX(t *testing.T) { atxHdlr := newTestHandler(t, types.RandomATXID()) atx := newInitialATXv1(t, atxHdlr.goldenATXID) - decoded, err := atxHdlr.decodeATX(codec.MustEncode(atx)) + decoded, err := atxHdlr.decodeATX(atx.Blob().Blob) require.NoError(t, err) require.Equal(t, atx, decoded) }) @@ -880,7 +880,7 @@ func TestHandler_DecodeATX(t *testing.T) { atx := newInitialATXv2(t, atxHdlr.goldenATXID) atx.PublishEpoch = 10 - decoded, err := atxHdlr.decodeATX(codec.MustEncode(atx)) + decoded, err := atxHdlr.decodeATX(atx.Blob().Blob) require.NoError(t, err) require.Equal(t, atx, decoded) }) @@ -891,7 +891,7 @@ func TestHandler_DecodeATX(t *testing.T) { atx := newInitialATXv2(t, atxHdlr.goldenATXID) atx.PublishEpoch = 9 - _, err := atxHdlr.decodeATX(codec.MustEncode(atx)) + _, err := atxHdlr.decodeATX(atx.Blob().Blob) require.ErrorIs(t, err, errMalformedData) }) } diff --git a/activation/handler_v1.go b/activation/handler_v1.go index 847739699e..1ba7c90d2d 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -550,7 +550,7 @@ func (h *HandlerV1) storeAtx( return fmt.Errorf("check malicious: %w", err) } - err = atxs.Add(tx, atx) + err = atxs.Add(tx, atx, watx.Blob()) if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("add atx to db: %w", err) } @@ -587,7 +587,6 @@ func (h *HandlerV1) processATX( ctx context.Context, peer p2p.Peer, watx *wire.ActivationTxV1, - blob []byte, received time.Time, ) (*mwire.MalfeasanceProof, error) { if !h.edVerifier.Verify(signing.ATX, watx.SmesherID, watx.SignedBytes(), watx.Signature) { @@ -644,7 +643,7 @@ func (h *HandlerV1) processATX( baseTickHeight = posAtx.TickHeight() } - atx := wire.ActivationTxFromWireV1(watx, blob...) + atx := wire.ActivationTxFromWireV1(watx) if h.nipostValidator.IsVerifyingFullPost() { atx.SetValidity(types.Valid) } diff --git a/activation/handler_v1_test.go b/activation/handler_v1_test.go index 02cef11cd4..97ce849ceb 100644 --- a/activation/handler_v1_test.go +++ b/activation/handler_v1_test.go @@ -13,7 +13,6 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/atxsdata" - "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire" @@ -67,7 +66,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { prevAtx.NumUnits = 100 prevAtx.Sign(sig) atxHdlr.expectAtxV1(prevAtx, sig.NodeID()) - _, err := atxHdlr.processATX(context.Background(), "", prevAtx, codec.MustEncode(prevAtx), time.Now()) + _, err := atxHdlr.processATX(context.Background(), "", prevAtx, time.Now()) require.NoError(t, err) otherSig, err := signing.NewEdSigner() @@ -76,7 +75,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { posAtx := newInitialATXv1(t, goldenATXID) posAtx.Sign(otherSig) atxHdlr.expectAtxV1(posAtx, otherSig.NodeID()) - _, err = atxHdlr.processATX(context.Background(), "", posAtx, codec.MustEncode(posAtx), time.Now()) + _, err = atxHdlr.processATX(context.Background(), "", posAtx, time.Now()) require.NoError(t, err) return atxHdlr, prevAtx, posAtx } @@ -488,14 +487,14 @@ func TestHandler_ContextuallyValidateAtx(t *testing.T) { atx0 := newInitialATXv1(t, goldenATXID) atx0.Sign(sig) atxHdlr.expectAtxV1(atx0, sig.NodeID()) - _, err := atxHdlr.processATX(context.Background(), "", atx0, codec.MustEncode(atx0), time.Now()) + _, err := atxHdlr.processATX(context.Background(), "", atx0, time.Now()) require.NoError(t, err) atx1 := newChainedActivationTxV1(t, atx0, goldenATXID) atx1.Sign(sig) atxHdlr.expectAtxV1(atx1, sig.NodeID()) atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), gomock.Any(), gomock.Any()) - _, err = atxHdlr.processATX(context.Background(), "", atx1, codec.MustEncode(atx1), time.Now()) + _, err = atxHdlr.processATX(context.Background(), "", atx1, time.Now()) require.NoError(t, err) atxInvalidPrevious := newChainedActivationTxV1(t, atx0, goldenATXID) @@ -515,13 +514,13 @@ func TestHandler_ContextuallyValidateAtx(t *testing.T) { atx0 := newInitialATXv1(t, goldenATXID) atx0.Sign(otherSig) atxHdlr.expectAtxV1(atx0, otherSig.NodeID()) - _, err = atxHdlr.processATX(context.Background(), "", atx0, codec.MustEncode(atx0), time.Now()) + _, err = atxHdlr.processATX(context.Background(), "", atx0, time.Now()) require.NoError(t, err) atx1 := newInitialATXv1(t, goldenATXID) atx1.Sign(sig) atxHdlr.expectAtxV1(atx1, sig.NodeID()) - _, err = atxHdlr.processATX(context.Background(), "", atx1, codec.MustEncode(atx1), time.Now()) + _, err = atxHdlr.processATX(context.Background(), "", atx1, time.Now()) require.NoError(t, err) atxInvalidPrevious := newChainedActivationTxV1(t, atx0, goldenATXID) @@ -555,7 +554,6 @@ func TestHandlerV1_StoreAtx(t *testing.T) { atxFromDb, err := atxs.Get(atxHdlr.cdb, atx.ID()) require.NoError(t, err) atx.SetReceived(time.Unix(0, atx.Received().UnixNano())) - atx.AtxBlob = types.AtxBlob{} require.Equal(t, atx, atxFromDb) }) @@ -605,7 +603,6 @@ func TestHandlerV1_StoreAtx(t *testing.T) { atxFromDb, err := atxs.Get(atxHdlr.cdb, atx.ID()) require.NoError(t, err) atx.SetReceived(time.Unix(0, atx.Received().UnixNano())) - atx.AtxBlob = types.AtxBlob{} require.Equal(t, atx, atxFromDb) }) diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 1d4fd60754..73aca06063 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -72,7 +72,6 @@ func (h *HandlerV2) processATX( ctx context.Context, peer p2p.Peer, watx *wire.ActivationTxV2, - blob []byte, received time.Time, ) (*mwire.MalfeasanceProof, error) { exists, err := atxs.Has(h.cdb, watx.ID()) @@ -129,7 +128,6 @@ func (h *HandlerV2) processATX( Weight: parts.weight, VRFNonce: types.VRFPostIndex(watx.VRFNonce), SmesherID: watx.SmesherID, - AtxBlob: types.AtxBlob{Blob: blob, Version: types.AtxV2}, } if watx.Initial == nil { @@ -746,7 +744,7 @@ func (h *HandlerV2) storeAtx( } } - err = atxs.Add(tx, atx) + err = atxs.Add(tx, atx, watx.Blob()) if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("add atx to db: %w", err) } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 3718f2bd25..4f80f9e9fb 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -17,7 +17,6 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/atxsdata" - "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire" @@ -191,13 +190,13 @@ func (h *v2TestHandler) createAndProcessInitial(t testing.TB, sig *signing.EdSig func (h *v2TestHandler) processInitial(t testing.TB, atx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) { t.Helper() h.expectInitialAtxV2(atx) - return h.processATX(context.Background(), peer.ID("peer"), atx, codec.MustEncode(atx), time.Now()) + return h.processATX(context.Background(), peer.ID("peer"), atx, time.Now()) } func (h *v2TestHandler) processSoloAtx(t testing.TB, atx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) { t.Helper() h.expectAtxV2(atx) - return h.processATX(context.Background(), peer.ID("peer"), atx, codec.MustEncode(atx), time.Now()) + return h.processATX(context.Background(), peer.ID("peer"), atx, time.Now()) } func TestHandlerV2_SyntacticallyValidate(t *testing.T) { @@ -461,13 +460,12 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { t.Parallel() atx := newInitialATXv2(t, golden) atx.Sign(sig) - blob := codec.MustEncode(atx) atxHandler := newV2TestHandler(t, golden) atxHandler.tickSize = tickSize atxHandler.expectInitialAtxV2(atx) - proof, err := atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) + proof, err := atxHandler.processATX(context.Background(), peer, atx, time.Now()) require.NoError(t, err) require.Nil(t, proof) @@ -482,7 +480,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { require.EqualValues(t, atx.NiPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) // processing ATX for the second time should skip checks - proof, err = atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) + proof, err = atxHandler.processATX(context.Background(), peer, atx, time.Now()) require.NoError(t, err) require.Nil(t, proof) }) @@ -494,10 +492,9 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), prev.ID()) atx.Sign(sig) - blob := codec.MustEncode(atx) atxHandler.expectAtxV2(atx) - proof, err := atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) + proof, err := atxHandler.processATX(context.Background(), peer, atx, time.Now()) require.NoError(t, err) require.Nil(t, proof) @@ -527,7 +524,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { atx := newSoloATXv2(t, prev.Epoch+1, prev.ID, golden) atx.Sign(sig) atxHandler.expectAtxV2(atx) - _, err := atxHandler.processATX(context.Background(), peer, atx, codec.MustEncode(atx), time.Now()) + _, err := atxHandler.processATX(context.Background(), peer, atx, time.Now()) require.NoError(t, err) atxFromDb, err := atxs.Get(atxHandler.cdb, atx.ID()) @@ -545,7 +542,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { atx.Sign(sig) atxHandler.expectAtxV2(atx) - proof, err := atxHandler.processATX(context.Background(), peer, atx, codec.MustEncode(atx), time.Now()) + proof, err := atxHandler.processATX(context.Background(), peer, atx, time.Now()) require.NoError(t, err) require.Nil(t, proof) @@ -573,7 +570,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { atx.TotalNumUnits(), ).Return(errors.New("vrf nonce is not valid")) - _, err = atxHandler.processATX(context.Background(), peer, atx, codec.MustEncode(atx), time.Now()) + _, err = atxHandler.processATX(context.Background(), peer, atx, time.Now()) require.ErrorContains(t, err, "vrf nonce is not valid") _, err = atxs.Get(atxHandler.cdb, atx.ID()) @@ -590,7 +587,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { atx.Sign(sig) atxHandler.expectAtxV2(atx) - proof, err := atxHandler.processATX(context.Background(), peer, atx, codec.MustEncode(atx), time.Now()) + proof, err := atxHandler.processATX(context.Background(), peer, atx, time.Now()) require.NoError(t, err) require.Nil(t, proof) @@ -608,7 +605,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { atxHandler.mclock.EXPECT().CurrentLayer() atxHandler.expectFetchDeps(atx) - _, err := atxHandler.processATX(context.Background(), peer, atx, codec.MustEncode(atx), time.Now()) + _, err := atxHandler.processATX(context.Background(), peer, atx, time.Now()) require.ErrorContains(t, err, "validating positioning atx") _, err = atxs.Get(atxHandler.cdb, atx.ID()) @@ -641,7 +638,7 @@ func marryIDs( mATX.Sign(sig) atxHandler.expectInitialAtxV2(mATX) - p, err := atxHandler.processATX(context.Background(), "", mATX, codec.MustEncode(mATX), time.Now()) + p, err := atxHandler.processATX(context.Background(), "", mATX, time.Now()) require.NoError(t, err) require.Nil(t, p) @@ -685,7 +682,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { merged.Sign(sig) atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{poetLeaves}) - p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + p, err := atxHandler.processATX(context.Background(), "", merged, time.Now()) require.NoError(t, err) require.Nil(t, p) @@ -749,7 +746,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { minPoetLeaves := slices.Min(poetLeaves) atxHandler.expectMergedAtxV2(merged, equivocationSet, poetLeaves) - p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + p, err := atxHandler.processATX(context.Background(), "", merged, time.Now()) require.NoError(t, err) require.Nil(t, p) @@ -802,7 +799,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler.expectFetchDeps(merged) atxHandler.expectVerifyNIPoSTs(merged, equivocationSet, []uint64{200}) - p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + p, err := atxHandler.processATX(context.Background(), "", merged, time.Now()) require.ErrorContains(t, err, "ATX signer not present in merged ATX") require.Nil(t, p) }) @@ -836,7 +833,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { merged.Sign(sig) atxHandler.mclock.EXPECT().CurrentLayer().Return(merged.PublishEpoch.FirstLayer()) - p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + p, err := atxHandler.processATX(context.Background(), "", merged, time.Now()) require.ErrorContains(t, err, "ID present twice (duplicated marriage index)") require.Nil(t, p) }) @@ -869,7 +866,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler.mclock.EXPECT().CurrentLayer().Return(merged.PublishEpoch.FirstLayer()) atxHandler.expectFetchDeps(merged) - p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + p, err := atxHandler.processATX(context.Background(), "", merged, time.Now()) require.Error(t, err) require.Nil(t, p) }) @@ -912,7 +909,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { merged.Sign(sig) atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) - p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + p, err := atxHandler.processATX(context.Background(), "", merged, time.Now()) require.NoError(t, err) require.Nil(t, p) @@ -939,7 +936,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler.mclock.EXPECT().CurrentLayer().Return(merged.PublishEpoch.FirstLayer()) atxHandler.expectFetchDeps(merged) - p, err = atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + p, err = atxHandler.processATX(context.Background(), "", merged, time.Now()) require.Error(t, err) require.Nil(t, p) }) @@ -1112,7 +1109,7 @@ func Test_ValidatePositioningAtx(t *testing.T) { BaseTickHeight: 100, } positioningAtx.SetID(types.RandomATXID()) - atxs.Add(atxHandler.cdb, positioningAtx) + atxs.Add(atxHandler.cdb, positioningAtx, types.AtxBlob{}) height, err := atxHandler.validatePositioningAtx(1, golden, positioningAtx.ID()) require.NoError(t, err) @@ -1126,7 +1123,7 @@ func Test_ValidatePositioningAtx(t *testing.T) { PublishEpoch: 1, } positioningAtx.SetID(types.RandomATXID()) - atxs.Add(atxHandler.cdb, positioningAtx) + atxs.Add(atxHandler.cdb, positioningAtx, types.AtxBlob{}) _, err := atxHandler.validatePositioningAtx(1, golden, positioningAtx.ID()) require.Error(t, err) @@ -1139,7 +1136,7 @@ func Test_ValidatePositioningAtx(t *testing.T) { PublishEpoch: 2, } positioningAtx.SetID(types.RandomATXID()) - atxs.Add(atxHandler.cdb, positioningAtx) + atxs.Add(atxHandler.cdb, positioningAtx, types.AtxBlob{}) _, err := atxHandler.validatePositioningAtx(1, golden, positioningAtx.ID()) require.Error(t, err) @@ -1193,7 +1190,7 @@ func Test_ValidateMarriages(t *testing.T) { marriage.Sign(sig) atxHandler.expectInitialAtxV2(marriage) - p, err := atxHandler.processATX(context.Background(), "", marriage, codec.MustEncode(marriage), time.Now()) + p, err := atxHandler.processATX(context.Background(), "", marriage, time.Now()) require.NoError(t, err) require.Nil(t, p) @@ -1227,7 +1224,7 @@ func Test_ValidateMarriages(t *testing.T) { marriage.Sign(sig) atxHandler.expectInitialAtxV2(marriage) - p, err := atxHandler.processATX(context.Background(), "", marriage, codec.MustEncode(marriage), time.Now()) + p, err := atxHandler.processATX(context.Background(), "", marriage, time.Now()) require.NoError(t, err) require.Nil(t, p) @@ -1298,7 +1295,7 @@ func Test_ValidateCommitmentAtx(t *testing.T) { atxHandler := newV2TestHandler(t, golden) commitment := &types.ActivationTx{PublishEpoch: 3} commitment.SetID(types.RandomATXID()) - require.NoError(t, atxs.Add(atxHandler.cdb, commitment)) + require.NoError(t, atxs.Add(atxHandler.cdb, commitment, types.AtxBlob{})) err := atxHandler.validateCommitmentAtx(golden, commitment.ID(), 4) require.NoError(t, err) }) @@ -1307,7 +1304,7 @@ func Test_ValidateCommitmentAtx(t *testing.T) { atxHandler := newV2TestHandler(t, golden) commitment := &types.ActivationTx{PublishEpoch: 3} commitment.SetID(types.RandomATXID()) - require.NoError(t, atxs.Add(atxHandler.cdb, commitment)) + require.NoError(t, atxs.Add(atxHandler.cdb, commitment, types.AtxBlob{})) err := atxHandler.validateCommitmentAtx(golden, commitment.ID(), 3) require.ErrorContains(t, err, "must be after commitment atx") err = atxHandler.validateCommitmentAtx(golden, commitment.ID(), 2) @@ -1593,7 +1590,7 @@ func Test_Marriages(t *testing.T) { atx.Sign(sig) atxHandler.expectInitialAtxV2(atx) - _, err = atxHandler.processATX(context.Background(), "", atx, codec.MustEncode(atx), time.Now()) + _, err = atxHandler.processATX(context.Background(), "", atx, time.Now()) require.NoError(t, err) // otherSig2 cannot marry sig, trying to extend its set. @@ -1616,7 +1613,7 @@ func Test_Marriages(t *testing.T) { for _, id := range ids { atxHandler.mtortoise.EXPECT().OnMalfeasance(id) } - proof, err := atxHandler.processATX(context.Background(), "", atx2, codec.MustEncode(atx2), time.Now()) + proof, err := atxHandler.processATX(context.Background(), "", atx2, time.Now()) require.NoError(t, err) // TODO: check the proof contents once its implemented require.NotNil(t, proof) @@ -1651,7 +1648,7 @@ func Test_Marriages(t *testing.T) { atx.Sign(sig) atxHandler.mclock.EXPECT().CurrentLayer().AnyTimes() - _, err = atxHandler.processATX(context.Background(), "", atx, codec.MustEncode(atx), time.Now()) + _, err = atxHandler.processATX(context.Background(), "", atx, time.Now()) require.ErrorContains(t, err, "signer must marry itself") }) } @@ -1700,7 +1697,7 @@ func Test_MarryingMalicious(t *testing.T) { atxHandler.mtortoise.EXPECT().OnMalfeasance(sig.NodeID()) atxHandler.mtortoise.EXPECT().OnMalfeasance(otherSig.NodeID()) - _, err := atxHandler.processATX(context.Background(), "", atx, codec.MustEncode(atx), time.Now()) + _, err := atxHandler.processATX(context.Background(), "", atx, time.Now()) require.NoError(t, err) equiv, err := identities.EquivocationSet(atxHandler.cdb, sig.NodeID()) diff --git a/activation/malfeasance_test.go b/activation/malfeasance_test.go index bc41510598..a90ac6c5de 100644 --- a/activation/malfeasance_test.go +++ b/activation/malfeasance_test.go @@ -33,7 +33,7 @@ func createIdentity(tb testing.TB, db sql.Executor, sig *signing.EdSigner) { atx.SetReceived(time.Now()) atx.SetID(types.RandomATXID()) atx.TickCount = 1 - require.NoError(tb, atxs.Add(db, atx)) + require.NoError(tb, atxs.Add(db, atx, types.AtxBlob{})) } type testMalfeasanceHandler struct { diff --git a/activation/post_test.go b/activation/post_test.go index de51d599df..b3e24eabda 100644 --- a/activation/post_test.go +++ b/activation/post_test.go @@ -282,7 +282,7 @@ func TestPostSetupManager_findCommitmentAtx_UsesLatestAtx(t *testing.T) { } atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - require.NoError(t, atxs.Add(mgr.db, atx)) + require.NoError(t, atxs.Add(mgr.db, atx, types.AtxBlob{})) mgr.atxsdata.AddFromAtx(atx, false) commitmentAtx, err := mgr.findCommitmentAtx(context.Background()) @@ -333,7 +333,7 @@ func TestPostSetupManager_getCommitmentAtx_getsCommitmentAtxFromInitialAtx(t *te atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - require.NoError(t, atxs.Add(mgr.cdb, atx)) + require.NoError(t, atxs.Add(mgr.cdb, atx, types.AtxBlob{})) atxid, err := mgr.commitmentAtx(context.Background(), mgr.opts.DataDir, signer.NodeID()) require.NoError(t, err) diff --git a/activation/validation_test.go b/activation/validation_test.go index 520384c006..a96c977f6e 100644 --- a/activation/validation_test.go +++ b/activation/validation_test.go @@ -14,7 +14,6 @@ import ( "go.uber.org/mock/gomock" "github.com/spacemeshos/go-spacemesh/activation/wire" - "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" @@ -487,13 +486,13 @@ func TestVerifyChainDeps(t *testing.T) { invalidAtx.Sign(signer) vInvalidAtx := toAtx(t, invalidAtx) vInvalidAtx.SetValidity(types.Invalid) - require.NoError(t, atxs.Add(db, vInvalidAtx)) + require.NoError(t, atxs.Add(db, vInvalidAtx, invalidAtx.Blob())) t.Run("invalid prev ATX", func(t *testing.T) { atx := newChainedActivationTxV1(t, invalidAtx, goldenATXID) atx.Sign(signer) vAtx := toAtx(t, atx) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) @@ -508,7 +507,7 @@ func TestVerifyChainDeps(t *testing.T) { atx := newInitialATXv1(t, invalidAtx.ID()) atx.Sign(signer) vAtx := toAtx(t, atx) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) @@ -525,7 +524,7 @@ func TestVerifyChainDeps(t *testing.T) { atx.Sign(signer) atx.CommitmentATXID = &commitmentAtxID vAtx := toAtx(t, atx) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) @@ -539,7 +538,7 @@ func TestVerifyChainDeps(t *testing.T) { atx := newInitialATXv1(t, invalidAtx.ID()) atx.Sign(signer) vAtx := toAtx(t, atx) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) @@ -552,7 +551,7 @@ func TestVerifyChainDeps(t *testing.T) { atx := newInitialATXv1(t, invalidAtx.ID()) atx.Sign(signer) vAtx := toAtx(t, atx) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) @@ -566,7 +565,7 @@ func TestVerifyChainDeps(t *testing.T) { atx := newInitialATXv1(t, goldenATXID) atx.Sign(signer) vAtx := toAtx(t, atx) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) @@ -584,13 +583,9 @@ func TestVerifyChainDeps(t *testing.T) { atx := &types.ActivationTx{ PublishEpoch: watx.PublishEpoch, SmesherID: watx.SmesherID, - AtxBlob: types.AtxBlob{ - Blob: codec.MustEncode(watx), - Version: types.AtxV2, - }, } atx.SetID(watx.ID()) - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, watx.Blob())) v := NewMockPostVerifier(gomock.NewController(t)) expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[0].Post)) @@ -602,20 +597,16 @@ func TestVerifyChainDeps(t *testing.T) { t.Run("non-initial V2 ATX", func(t *testing.T) { initialAtx := newInitialATXv1(t, goldenATXID) initialAtx.Sign(signer) - require.NoError(t, atxs.Add(db, toAtx(t, initialAtx))) + require.NoError(t, atxs.Add(db, toAtx(t, initialAtx), initialAtx.Blob())) watx := newSoloATXv2(t, initialAtx.PublishEpoch+1, initialAtx.ID(), initialAtx.ID()) watx.Sign(signer) atx := &types.ActivationTx{ PublishEpoch: watx.PublishEpoch, SmesherID: watx.SmesherID, - AtxBlob: types.AtxBlob{ - Blob: codec.MustEncode(watx), - Version: types.AtxV2, - }, } atx.SetID(watx.ID()) - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, watx.Blob())) v := NewMockPostVerifier(gomock.NewController(t)) expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[0].Post)) @@ -654,7 +645,7 @@ func TestVerifyChainDepsAfterCheckpoint(t *testing.T) { atx := newChainedActivationTxV1(t, checkpointedAtx, checkpointedAtx.ID()) atx.Sign(signer) vAtx := toAtx(t, atx) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) diff --git a/activation/verify_state_test.go b/activation/verify_state_test.go index b0e46a78c8..26446d5926 100644 --- a/activation/verify_state_test.go +++ b/activation/verify_state_test.go @@ -34,7 +34,7 @@ func Test_CheckPrevATXs(t *testing.T) { }) atx1.Sign(sig) vAtx1 := toAtx(t, atx1) - require.NoError(t, atxs.Add(db, vAtx1)) + require.NoError(t, atxs.Add(db, vAtx1, atx1.Blob())) atx2 := newInitialATXv1(t, goldenATXID, func(atx *wire.ActivationTxV1) { atx.PrevATXID = prevATXID @@ -42,7 +42,7 @@ func Test_CheckPrevATXs(t *testing.T) { }) atx2.Sign(sig) vAtx2 := toAtx(t, atx2) - require.NoError(t, atxs.Add(db, vAtx2)) + require.NoError(t, atxs.Add(db, vAtx2, atx2.Blob())) // create 100 random ATXs that are not malicious for i := 0; i < 100; i++ { @@ -55,7 +55,7 @@ func Test_CheckPrevATXs(t *testing.T) { }) atx.Sign(otherSig) vAtx := toAtx(t, atx) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) } // Act diff --git a/activation/wire/wire_v1.go b/activation/wire/wire_v1.go index e0fe9506a3..d76e343ab0 100644 --- a/activation/wire/wire_v1.go +++ b/activation/wire/wire_v1.go @@ -21,7 +21,8 @@ type ActivationTxV1 struct { SmesherID types.NodeID Signature types.EdSignature - id types.ATXID + id types.ATXID + blob []byte } // InnerActivationTxV1 is a set of all of an ATX's fields, except the signature. To generate the ATX signature, this @@ -123,6 +124,26 @@ func (atx *ActivationTxV1) SignedBytes() []byte { return data } +func (atx *ActivationTxV1) Blob() types.AtxBlob { + if len(atx.blob) == 0 { + atx.blob = codec.MustEncode(atx) + } + return types.AtxBlob{ + Blob: atx.blob, + Version: types.AtxV1, + } +} + +func DecodeAtxV1(blob []byte) (*ActivationTxV1, error) { + atx := &ActivationTxV1{ + blob: blob, + } + if err := codec.Decode(blob, atx); err != nil { + return nil, err + } + return atx, nil +} + func (atx *ActivationTxV1) HashInnerBytes() (result types.Hash32) { h := hash.GetHasher() defer hash.PutHasher(h) @@ -171,7 +192,7 @@ func NIPostChallengeToWireV1(c *types.NIPostChallenge) *NIPostChallengeV1 { } } -func ActivationTxFromWireV1(atx *ActivationTxV1, blob ...byte) *types.ActivationTx { +func ActivationTxFromWireV1(atx *ActivationTxV1) *types.ActivationTx { result := &types.ActivationTx{ PublishEpoch: atx.PublishEpoch, Sequence: atx.Sequence, @@ -180,13 +201,6 @@ func ActivationTxFromWireV1(atx *ActivationTxV1, blob ...byte) *types.Activation Coinbase: atx.Coinbase, NumUnits: atx.NumUnits, SmesherID: atx.SmesherID, - AtxBlob: types.AtxBlob{ - Version: types.AtxV1, - Blob: blob, - }, - } - if len(blob) == 0 { - result.AtxBlob.Blob = codec.MustEncode(atx) } result.SetID(atx.ID()) diff --git a/activation/wire/wire_v2.go b/activation/wire/wire_v2.go index d439ffa20d..7cf29c842e 100644 --- a/activation/wire/wire_v2.go +++ b/activation/wire/wire_v2.go @@ -42,13 +42,34 @@ type ActivationTxV2 struct { Signature types.EdSignature // cached fields to avoid repeated calculations - id types.ATXID + id types.ATXID + blob []byte } func (atx *ActivationTxV2) SignedBytes() []byte { return atx.ID().Bytes() } +func (atx *ActivationTxV2) Blob() types.AtxBlob { + if len(atx.blob) == 0 { + atx.blob = codec.MustEncode(atx) + } + return types.AtxBlob{ + Blob: atx.blob, + Version: types.AtxV2, + } +} + +func DecodeAtxV2(blob []byte) (*ActivationTxV2, error) { + atx := &ActivationTxV2{ + blob: blob, + } + if err := codec.Decode(blob, atx); err != nil { + return nil, err + } + return atx, nil +} + func (atx *ActivationTxV2) merkleTree(tree *merkle.Tree) { publishEpoch := make([]byte, 4) binary.LittleEndian.PutUint32(publishEpoch, atx.PublishEpoch.Uint32()) diff --git a/api/grpcserver/admin_service_test.go b/api/grpcserver/admin_service_test.go index 526b5420af..e68222adbc 100644 --- a/api/grpcserver/admin_service_test.go +++ b/api/grpcserver/admin_service_test.go @@ -37,7 +37,7 @@ func newAtx(tb testing.TB, db *sql.Database) { atx.SetID(types.RandomATXID()) atx.SmesherID = types.BytesToNodeID(types.RandomBytes(20)) atx.SetReceived(time.Now().Local()) - require.NoError(tb, atxs.Add(db, atx)) + require.NoError(tb, atxs.Add(db, atx, types.AtxBlob{})) require.NoError(tb, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits)) } diff --git a/api/grpcserver/grpcserver_test.go b/api/grpcserver/grpcserver_test.go index e4a47fc971..bf9b7c3769 100644 --- a/api/grpcserver/grpcserver_test.go +++ b/api/grpcserver/grpcserver_test.go @@ -2503,7 +2503,7 @@ func TestMeshService_EpochStream(t *testing.T) { all := createAtxs(t, epoch, atxids) var expected, got []types.ATXID for i, vatx := range all { - require.NoError(t, atxs.Add(db, vatx)) + require.NoError(t, atxs.Add(db, vatx, types.AtxBlob{})) if i%2 == 0 { require.NoError(t, identities.SetMalicious(db, vatx.SmesherID, []byte("bad"), time.Now())) } else { diff --git a/api/grpcserver/v2alpha1/activation_test.go b/api/grpcserver/v2alpha1/activation_test.go index c88d572831..70b97330fe 100644 --- a/api/grpcserver/v2alpha1/activation_test.go +++ b/api/grpcserver/v2alpha1/activation_test.go @@ -29,7 +29,7 @@ func TestActivationService_List(t *testing.T) { for i := range activations { atx := gen.Next() vAtx := fixture.ToAtx(t, atx) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) activations[i] = *vAtx } @@ -112,7 +112,7 @@ func TestActivationStreamService_Stream(t *testing.T) { for i := range activations { atx := gen.Next() vAtx := fixture.ToAtx(t, atx) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) activations[i] = *vAtx } @@ -153,8 +153,9 @@ func TestActivationStreamService_Stream(t *testing.T) { gen = fixture.NewAtxsGenerator().WithEpochs(start, 10) var streamed []*events.ActivationTx for i := 0; i < n; i++ { - atx := fixture.ToAtx(t, gen.Next()) - require.NoError(t, atxs.Add(db, atx)) + watx := gen.Next() + atx := fixture.ToAtx(t, watx) + require.NoError(t, atxs.Add(db, atx, watx.Blob())) streamed = append(streamed, &events.ActivationTx{ActivationTx: atx}) } @@ -221,7 +222,7 @@ func TestActivationService_ActivationsCount(t *testing.T) { for i := range epoch3ATXs { atx := genEpoch3.Next() vatx := fixture.ToAtx(t, atx) - require.NoError(t, atxs.Add(db, vatx)) + require.NoError(t, atxs.Add(db, vatx, atx.Blob())) epoch3ATXs[i] = *vatx } @@ -231,7 +232,7 @@ func TestActivationService_ActivationsCount(t *testing.T) { for i := range epoch5ATXs { atx := genEpoch5.Next() vatx := fixture.ToAtx(t, atx) - require.NoError(t, atxs.Add(db, vatx)) + require.NoError(t, atxs.Add(db, vatx, atx.Blob())) epoch5ATXs[i] = *vatx } diff --git a/atxsdata/warmup_test.go b/atxsdata/warmup_test.go index 2b82935000..67fa598140 100644 --- a/atxsdata/warmup_test.go +++ b/atxsdata/warmup_test.go @@ -49,7 +49,7 @@ func TestWarmup(t *testing.T) { gatx(types.ATXID{3, 3}, 3, types.NodeID{3}, nonce), } for i := range data { - require.NoError(t, atxs.Add(db, &data[i])) + require.NoError(t, atxs.Add(db, &data[i], types.AtxBlob{})) } require.NoError(t, layers.SetApplied(db, applied, types.BlockID{1})) @@ -75,7 +75,7 @@ func TestWarmup(t *testing.T) { db := sql.InMemory() nonce := types.VRFPostIndex(1) data := gatx(types.ATXID{1, 1}, 1, types.NodeID{1}, nonce) - require.NoError(t, atxs.Add(db, &data)) + require.NoError(t, atxs.Add(db, &data, types.AtxBlob{})) exec := mocks.NewMockExecutor(gomock.NewController(t)) call := 0 diff --git a/beacon/beacon_test.go b/beacon/beacon_test.go index bdc1c54fa7..5e3943d32d 100644 --- a/beacon/beacon_test.go +++ b/beacon/beacon_test.go @@ -127,7 +127,7 @@ func createATX( atx.SetReceived(received) atx.SetID(types.RandomATXID()) - require.NoError(tb, atxs.Add(db, &atx)) + require.NoError(tb, atxs.Add(db, &atx, types.AtxBlob{})) return atx.ID() } diff --git a/checkpoint/recovery_collecting_deps_test.go b/checkpoint/recovery_collecting_deps_test.go index 3da864eafb..1df6ca0a9b 100644 --- a/checkpoint/recovery_collecting_deps_test.go +++ b/checkpoint/recovery_collecting_deps_test.go @@ -8,7 +8,6 @@ import ( "golang.org/x/exp/maps" "github.com/spacemeshos/go-spacemesh/activation/wire" - "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" @@ -30,7 +29,7 @@ func TestCollectingDeps(t *testing.T) { }, SmesherID: types.RandomNodeID(), } - require.NoError(t, atxs.Add(db, fixture.ToAtx(t, marriageATX))) + require.NoError(t, atxs.Add(db, fixture.ToAtx(t, marriageATX), marriageATX.Blob())) mAtxID := marriageATX.ID() watx := &wire.ActivationTxV2{ @@ -39,15 +38,11 @@ func TestCollectingDeps(t *testing.T) { MarriageATX: &mAtxID, } atx := &types.ActivationTx{ - AtxBlob: types.AtxBlob{ - Version: types.AtxV2, - Blob: codec.MustEncode(watx), - }, SmesherID: watx.SmesherID, } atx.SetID(watx.ID()) atx.SetReceived(time.Now()) - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, watx.Blob())) // marry the two IDs err := identities.SetMarriage(db, marriageATX.SmesherID, &identities.MarriageData{ATX: marriageATX.ID()}) diff --git a/checkpoint/recovery_test.go b/checkpoint/recovery_test.go index bd171d9d60..d5dc54a03a 100644 --- a/checkpoint/recovery_test.go +++ b/checkpoint/recovery_test.go @@ -273,46 +273,45 @@ func validateAndPreserveData( for _, dep := range deps { var atx wire.ActivationTxV1 require.NoError(tb, codec.Decode(dep.Blob, &atx)) - vatx := wire.ActivationTxFromWireV1(&atx, dep.Blob...) - mclock.EXPECT().CurrentLayer().Return(vatx.PublishEpoch.FirstLayer()) + mclock.EXPECT().CurrentLayer().Return(atx.PublishEpoch.FirstLayer()) mfetch.EXPECT().RegisterPeerHashes(gomock.Any(), gomock.Any()) mfetch.EXPECT().GetPoetProof(gomock.Any(), gomock.Any()) - if vatx.PrevATXID == types.EmptyATXID { + if atx.PrevATXID == types.EmptyATXID { mvalidator.EXPECT(). InitialNIPostChallengeV1(&atx.NIPostChallengeV1, gomock.Any(), goldenAtx). AnyTimes() mvalidator.EXPECT().Post( gomock.Any(), - vatx.SmesherID, - *vatx.CommitmentATX, + atx.SmesherID, + *atx.CommitmentATXID, wire.PostFromWireV1(atx.InitialPost), gomock.Any(), - vatx.NumUnits, + atx.NumUnits, gomock.Any(), ) mvalidator.EXPECT().VRFNonce( - vatx.SmesherID, - *vatx.CommitmentATX, - (uint64)(vatx.VRFNonce), + atx.SmesherID, + *atx.CommitmentATXID, + *atx.VRFNonce, atx.NIPost.PostMetadata.LabelsPerUnit, - vatx.NumUnits, + atx.NumUnits, ) } else { mvalidator.EXPECT().NIPostChallengeV1( &atx.NIPostChallengeV1, gomock.Cond(func(prev any) bool { return prev.(*types.ActivationTx).ID() == atx.PrevATXID }), - vatx.SmesherID, + atx.SmesherID, ) } - mvalidator.EXPECT().PositioningAtx(atx.PositioningATXID, cdb, goldenAtx, vatx.PublishEpoch) + mvalidator.EXPECT().PositioningAtx(atx.PositioningATXID, cdb, goldenAtx, atx.PublishEpoch) mvalidator.EXPECT(). - NIPost(gomock.Any(), vatx.SmesherID, gomock.Any(), gomock.Any(), gomock.Any(), vatx.NumUnits, gomock.Any()). + NIPost(gomock.Any(), atx.SmesherID, gomock.Any(), gomock.Any(), gomock.Any(), atx.NumUnits, gomock.Any()). Return(uint64(1111111), nil) mvalidator.EXPECT().IsVerifyingFullPost().AnyTimes().Return(true) mreceiver.EXPECT().OnAtx(gomock.Any()) mtrtl.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) - require.NoError(tb, atxHandler.HandleSyncedAtx(context.Background(), vatx.ID().Hash32(), "self", dep.Blob)) + require.NoError(tb, atxHandler.HandleSyncedAtx(context.Background(), atx.ID().Hash32(), "self", dep.Blob)) } } @@ -353,12 +352,9 @@ func newChainedAtx( } watx.Signature = sig.Sign(signing.ATX, watx.SignedBytes()) - atx := wire.ActivationTxFromWireV1(watx) - atx.SetReceived(time.Now().Local()) - return &checkpoint.AtxDep{ - ID: atx.ID(), - PublishEpoch: atx.PublishEpoch, + ID: watx.ID(), + PublishEpoch: watx.PublishEpoch, Blob: codec.MustEncode(watx), } } @@ -628,19 +624,15 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) { require.NoError(t, oldDB.Close()) // write pending nipost challenge to simulate a pending atx still waiting for poet proof. - var atx wire.ActivationTxV1 - require.NoError(t, codec.Decode(vAtxs1[len(vAtxs1)-2].Blob, &atx)) - prevAtx1 := wire.ActivationTxFromWireV1(&atx) - atx = wire.ActivationTxV1{} - require.NoError(t, codec.Decode(vAtxs1[len(vAtxs1)-1].Blob, &atx)) - posAtx1 := wire.ActivationTxFromWireV1(&atx) - - atx = wire.ActivationTxV1{} - require.NoError(t, codec.Decode(vAtxs2[len(vAtxs1)-2].Blob, &atx)) - prevAtx2 := wire.ActivationTxFromWireV1(&atx) - atx = wire.ActivationTxV1{} - require.NoError(t, codec.Decode(vAtxs2[len(vAtxs1)-1].Blob, &atx)) - posAtx2 := wire.ActivationTxFromWireV1(&atx) + var prevAtx1 wire.ActivationTxV1 + require.NoError(t, codec.Decode(vAtxs1[len(vAtxs1)-2].Blob, &prevAtx1)) + var posAtx1 wire.ActivationTxV1 + require.NoError(t, codec.Decode(vAtxs1[len(vAtxs1)-1].Blob, &posAtx1)) + + var prevAtx2 wire.ActivationTxV1 + require.NoError(t, codec.Decode(vAtxs2[len(vAtxs1)-2].Blob, &prevAtx2)) + var posAtx2 wire.ActivationTxV1 + require.NoError(t, codec.Decode(vAtxs2[len(vAtxs1)-1].Blob, &posAtx2)) localDB, err := localsql.Open("file:" + filepath.Join(cfg.DataDir, cfg.LocalDbFile)) require.NoError(t, err) @@ -818,14 +810,13 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { require.NotNil(t, oldDB) vAtxs, proofs := createAtxChain(t, sig) // make the first one from the previous snapshot - var atx wire.ActivationTxV1 - require.NoError(t, codec.Decode(vAtxs[0].Blob, &atx)) - golden := wire.ActivationTxFromWireV1(&atx, vAtxs[0].Blob...) + var golden wire.ActivationTxV1 + require.NoError(t, codec.Decode(vAtxs[0].Blob, &golden)) require.NoError(t, atxs.AddCheckpointed(oldDB, &atxs.CheckpointAtx{ ID: golden.ID(), Epoch: golden.PublishEpoch, - CommitmentATX: *golden.CommitmentATX, - VRFNonce: golden.VRFNonce, + CommitmentATX: *golden.CommitmentATXID, + VRFNonce: types.VRFPostIndex(*golden.VRFNonce), NumUnits: golden.NumUnits, SmesherID: golden.SmesherID, Sequence: golden.Sequence, @@ -968,7 +959,7 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) - require.NoError(t, atxs.Add(oldDB, atx)) + require.NoError(t, atxs.Add(oldDB, atx, types.AtxBlob{})) require.NoError(t, oldDB.Close()) preserve, err := checkpoint.Recover(ctx, zaptest.NewLogger(t), afero.NewOsFs(), cfg) diff --git a/checkpoint/runner_test.go b/checkpoint/runner_test.go index 676b993f20..f7009c24ec 100644 --- a/checkpoint/runner_test.go +++ b/checkpoint/runner_test.go @@ -257,7 +257,7 @@ func createMesh(t testing.TB, db *sql.Database, miners []miner, accts []*types.A t.Helper() for _, miner := range miners { for _, atx := range miner.atxs { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) require.NoError(t, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits)) } if proof := miner.malfeasanceProof; len(proof) > 0 { diff --git a/cmd/bootstrapper/generator_test.go b/cmd/bootstrapper/generator_test.go index a0e96e72e4..b28e3b7f56 100644 --- a/cmd/bootstrapper/generator_test.go +++ b/cmd/bootstrapper/generator_test.go @@ -55,7 +55,7 @@ func createAtxs(tb testing.TB, db sql.Executor, epoch types.EpochID, atxids []ty } atx.SetID(id) atx.SetReceived(time.Now()) - require.NoError(tb, atxs.Add(db, atx)) + require.NoError(tb, atxs.Add(db, atx, types.AtxBlob{})) } } diff --git a/common/types/activation.go b/common/types/activation.go index 86579e29f8..eb6805d456 100644 --- a/common/types/activation.go +++ b/common/types/activation.go @@ -193,8 +193,6 @@ type ActivationTx struct { // for at least the first few years. Weight uint64 - AtxBlob - golden bool id ATXID // non-exported cache of the ATXID received time.Time // time received by node, gossiped or synced diff --git a/datastore/store_test.go b/datastore/store_test.go index cc52b1415e..c89f01dd09 100644 --- a/datastore/store_test.go +++ b/datastore/store_test.go @@ -168,7 +168,7 @@ func TestBlobStore_GetATXBlob(t *testing.T) { _, err = getBytes(ctx, bs, datastore.ATXDB, atx.ID()) require.ErrorIs(t, err, datastore.ErrNotFound) - require.NoError(t, atxs.Add(db, vAtx)) + require.NoError(t, atxs.Add(db, vAtx, atx.Blob())) has, err = bs.Has(datastore.ATXDB, atx.ID().Bytes()) require.NoError(t, err) @@ -176,10 +176,10 @@ func TestBlobStore_GetATXBlob(t *testing.T) { got, err := getBytes(ctx, bs, datastore.ATXDB, atx.ID()) require.NoError(t, err) - var gotA wire.ActivationTxV1 - codec.MustDecode(got, &gotA) + gotA, err := wire.DecodeAtxV1(got) + require.NoError(t, err) require.Equal(t, atx.ID(), gotA.ID()) - require.Equal(t, atx, &gotA) + require.Equal(t, atx, gotA) _, err = getBytes(ctx, bs, datastore.BallotDB, atx.ID()) require.ErrorIs(t, err, datastore.ErrNotFound) diff --git a/fetch/handler_test.go b/fetch/handler_test.go index 87c48f89bc..5cdc2afc88 100644 --- a/fetch/handler_test.go +++ b/fetch/handler_test.go @@ -262,7 +262,7 @@ func TestHandleMeshHashReq(t *testing.T) { } } -func newAtx(t *testing.T, published types.EpochID) *types.ActivationTx { +func newAtx(t *testing.T, published types.EpochID) (*types.ActivationTx, types.AtxBlob) { t.Helper() nonce := uint64(123) signer, err := signing.NewEdSigner() @@ -278,7 +278,7 @@ func newAtx(t *testing.T, published types.EpochID) *types.ActivationTx { }, } atx.Sign(signer) - return fixture.ToAtx(t, atx) + return fixture.ToAtx(t, atx), atx.Blob() } func TestHandleEpochInfoReq(t *testing.T) { @@ -304,8 +304,8 @@ func TestHandleEpochInfoReq(t *testing.T) { var expected EpochData if !tc.missingData { for i := 0; i < 10; i++ { - vatx := newAtx(t, epoch) - require.NoError(t, atxs.Add(th.cdb, vatx)) + vatx, blob := newAtx(t, epoch) + require.NoError(t, atxs.Add(th.cdb, vatx, blob)) expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) } } @@ -353,8 +353,8 @@ func testHandleEpochInfoReqWithQueryCache( var expected EpochData for i := 0; i < 10; i++ { - vatx := newAtx(t, epoch) - require.NoError(t, atxs.Add(th.cdb, vatx)) + vatx, blob := newAtx(t, epoch) + require.NoError(t, atxs.Add(th.cdb, vatx, blob)) atxs.AtxAdded(th.cdb, vatx) expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) } @@ -372,8 +372,8 @@ func testHandleEpochInfoReqWithQueryCache( } // Add another ATX which should be appended to the cached slice - vatx := newAtx(t, epoch) - require.NoError(t, atxs.Add(th.cdb, vatx)) + vatx, blob := newAtx(t, epoch) + require.NoError(t, atxs.Add(th.cdb, vatx, blob)) atxs.AtxAdded(th.cdb, vatx) expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) require.Equal(t, 23, qc.QueryCount()) diff --git a/fetch/p2p_test.go b/fetch/p2p_test.go index 5557bbcea1..b10412d7fa 100644 --- a/fetch/p2p_test.go +++ b/fetch/p2p_test.go @@ -165,8 +165,8 @@ func createP2PFetch( func (tpf *testP2PFetch) createATXs(epoch types.EpochID) []types.ATXID { atxIDs := make([]types.ATXID, 10) for i := range atxIDs { - atx := newAtx(tpf.t, epoch) - require.NoError(tpf.t, atxs.Add(tpf.serverCDB, atx)) + atx, blob := newAtx(tpf.t, epoch) + require.NoError(tpf.t, atxs.Add(tpf.serverCDB, atx, blob)) atxIDs[i] = atx.ID() } return atxIDs @@ -345,15 +345,15 @@ func TestP2PGetATXs(t *testing.T) { t, "database: no free connection", func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { epoch := types.EpochID(11) - atx := newAtx(tpf.t, epoch) - require.NoError(tpf.t, atxs.Add(tpf.serverCDB, atx)) + atx, blob := newAtx(tpf.t, epoch) + require.NoError(tpf.t, atxs.Add(tpf.serverCDB, atx, blob)) tpf.verifyGetHash( func() error { return tpf.clientFetch.GetAtxs( context.Background(), []types.ATXID{atx.ID()}) }, errStr, "atx", "hs/1", types.Hash32(atx.ID()), atx.ID().Bytes(), - atx.AtxBlob.Blob) + blob.Blob) }) } diff --git a/hare3/eligibility/oracle_test.go b/hare3/eligibility/oracle_test.go index 7c24a673b5..ff1e13bf22 100644 --- a/hare3/eligibility/oracle_test.go +++ b/hare3/eligibility/oracle_test.go @@ -155,7 +155,7 @@ func (t *testOracle) createActiveSet( func (t *testOracle) addAtx(atx *types.ActivationTx) { t.tb.Helper() - require.NoError(t.tb, atxs.Add(t.db, atx)) + require.NoError(t.tb, atxs.Add(t.db, atx, types.AtxBlob{})) t.atxsdata.AddFromAtx(atx, false) } @@ -903,7 +903,7 @@ func TestActiveSetMatrix(t *testing.T) { require.NoError(t, ballots.Add(oracle.db, &ballot)) } for _, atx := range tc.atxs { - require.NoError(t, atxs.Add(oracle.db, atx)) + require.NoError(t, atxs.Add(oracle.db, atx, types.AtxBlob{})) oracle.atxsdata.AddFromAtx(atx, false) } if tc.beacon != types.EmptyBeacon { diff --git a/hare3/hare_test.go b/hare3/hare_test.go index 8bb1ae6164..71de7f71c3 100644 --- a/hare3/hare_test.go +++ b/hare3/hare_test.go @@ -249,7 +249,7 @@ func (n *node) register(signer *signing.EdSigner) { } func (n *node) storeAtx(atx *types.ActivationTx) error { - if err := atxs.Add(n.db, atx); err != nil { + if err := atxs.Add(n.db, atx, types.AtxBlob{}); err != nil { return err } n.atxsdata.AddFromAtx(atx, false) @@ -914,7 +914,7 @@ func TestProposals(t *testing.T) { WithLogger(logtest.New(t).Zap()), ) for _, atx := range tc.atxs { - require.NoError(t, atxs.Add(db, &atx)) + require.NoError(t, atxs.Add(db, &atx, types.AtxBlob{})) atxsdata.AddFromAtx(&atx, false) } for _, proposal := range tc.proposals { diff --git a/hare3/malfeasance_test.go b/hare3/malfeasance_test.go index e3d1547b7f..0f2a0f1491 100644 --- a/hare3/malfeasance_test.go +++ b/hare3/malfeasance_test.go @@ -54,7 +54,7 @@ func createIdentity(tb testing.TB, db sql.Executor, sig *signing.EdSigner) { atx.SetReceived(time.Now()) atx.SetID(types.RandomATXID()) atx.TickCount = 1 - require.NoError(tb, atxs.Add(db, atx)) + require.NoError(tb, atxs.Add(db, atx, types.AtxBlob{})) } func TestHandler_Validate(t *testing.T) { diff --git a/mesh/executor_test.go b/mesh/executor_test.go index 01645d640f..a4937f19a2 100644 --- a/mesh/executor_test.go +++ b/mesh/executor_test.go @@ -80,7 +80,7 @@ func (t *testExecutor) createATX(epoch types.EpochID, cb types.Address) (types.A } atx.SetReceived(time.Now()) atx.SetID(types.RandomATXID()) - require.NoError(t.tb, atxs.Add(t.db, atx)) + require.NoError(t.tb, atxs.Add(t.db, atx, types.AtxBlob{})) t.atxsdata.AddFromAtx(atx, false) return atx.ID(), sig.NodeID() } diff --git a/mesh/malfeasance_test.go b/mesh/malfeasance_test.go index 1d673da5e1..8e4c607bf8 100644 --- a/mesh/malfeasance_test.go +++ b/mesh/malfeasance_test.go @@ -54,7 +54,7 @@ func createIdentity(tb testing.TB, db sql.Executor, sig *signing.EdSigner) { atx.SetReceived(time.Now()) atx.SetID(types.RandomATXID()) atx.TickCount = 1 - require.NoError(tb, atxs.Add(db, atx)) + require.NoError(tb, atxs.Add(db, atx, types.AtxBlob{})) } func TestHandler_Validate(t *testing.T) { diff --git a/miner/active_set_generator_test.go b/miner/active_set_generator_test.go index a3efc935b6..660952d03c 100644 --- a/miner/active_set_generator_test.go +++ b/miner/active_set_generator_test.go @@ -252,7 +252,7 @@ func TestActiveSetGenerate(t *testing.T) { config{networkDelay: tc.networkDelay, goodAtxPercent: tc.goodAtxPercent}, ) for _, atx := range tc.atxs { - require.NoError(t, atxs.Add(tester.db, atx)) + require.NoError(t, atxs.Add(tester.db, atx, types.AtxBlob{})) tester.atxsdata.AddFromAtx(atx, false) } for _, identity := range tc.malfeasent { diff --git a/miner/proposal_builder_test.go b/miner/proposal_builder_test.go index 0ab3093f05..6df1c62592 100644 --- a/miner/proposal_builder_test.go +++ b/miner/proposal_builder_test.go @@ -777,7 +777,7 @@ func TestBuild(t *testing.T) { ) } for _, atx := range step.atxs { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) atxsdata.AddFromAtx(atx, false) } for _, ballot := range step.ballots { diff --git a/proposals/handler_test.go b/proposals/handler_test.go index 2fbff099f3..d2d52e6293 100644 --- a/proposals/handler_test.go +++ b/proposals/handler_test.go @@ -245,7 +245,7 @@ func createAtx(t *testing.T, db *sql.Database, epoch types.EpochID, atxID types. } atx.SetID(atxID) atx.SetReceived(time.Now()) - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) } func createBallot(t *testing.T, opts ...createBallotOpt) *types.Ballot { diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 837624dd3e..6cff8db611 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -418,7 +418,7 @@ func NonceByID(db sql.Executor, id types.ATXID) (nonce types.VRFPostIndex, err e return nonce, err } -func Add(db sql.Executor, atx *types.ActivationTx) error { +func Add(db sql.Executor, atx *types.ActivationTx, blob types.AtxBlob) error { enc := func(stmt *sql.Statement) { stmt.BindBytes(1, atx.ID().Bytes()) stmt.BindInt64(2, int64(atx.PublishEpoch)) @@ -453,7 +453,7 @@ func Add(db sql.Executor, atx *types.ActivationTx) error { return fmt.Errorf("insert ATX ID %v: %w", atx.ID(), err) } - return AddBlob(db, atx.ID(), atx.Blob, atx.Version) + return AddBlob(db, atx.ID(), blob.Blob, blob.Version) } func AddBlob(db sql.Executor, id types.ATXID, blob []byte, version types.AtxVersion) error { diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index b856db4190..d38e783601 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -34,18 +34,14 @@ func TestGet(t *testing.T) { for i := 0; i < 3; i++ { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, sig, withPublishEpoch(types.EpochID(i))) + atx, blob := newAtx(t, sig, withPublishEpoch(types.EpochID(i))) atxList = append(atxList, atx) - } - - for _, atx := range atxList { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, blob)) } for _, want := range atxList { got, err := atxs.Get(db, want.ID()) require.NoError(t, err) - want.AtxBlob = types.AtxBlob{} require.Equal(t, want, got) } @@ -56,17 +52,12 @@ func TestGet(t *testing.T) { func TestAll(t *testing.T) { db := sql.InMemory() - atxList := make([]*types.ActivationTx, 0) + var expected []types.ATXID for i := 0; i < 3; i++ { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, sig, withPublishEpoch(types.EpochID(i))) - atxList = append(atxList, atx) - } - - var expected []types.ATXID - for _, atx := range atxList { - require.NoError(t, atxs.Add(db, atx)) + atx, blob := newAtx(t, sig, withPublishEpoch(types.EpochID(i))) + require.NoError(t, atxs.Add(db, atx, blob)) expected = append(expected, atx.ID()) } @@ -82,14 +73,11 @@ func TestHasID(t *testing.T) { for i := 0; i < 3; i++ { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, sig, withPublishEpoch(types.EpochID(i))) + atx, blob := newAtx(t, sig, withPublishEpoch(types.EpochID(i))) + require.NoError(t, atxs.Add(db, atx, blob)) atxList = append(atxList, atx) } - for _, atx := range atxList { - require.NoError(t, atxs.Add(db, atx)) - } - for _, atx := range atxList { has, err := atxs.Has(db, atx.ID()) require.NoError(t, err) @@ -111,8 +99,8 @@ func Test_IdentityExists(t *testing.T) { require.NoError(t, err) require.False(t, yes) - atx := newAtx(t, sig) - require.NoError(t, atxs.Add(db, atx)) + atx, blob := newAtx(t, sig) + require.NoError(t, atxs.Add(db, atx, blob)) yes, err = atxs.IdentityExists(db, sig.NodeID()) require.NoError(t, err) @@ -130,18 +118,16 @@ func TestGetFirstIDByNodeID(t *testing.T) { require.NoError(t, err) // Arrange - - atx1 := newAtx(t, sig1, withPublishEpoch(1)) - atx2 := newAtx(t, sig1, withPublishEpoch(2), withSequence(atx1.Sequence+1)) - atx3 := newAtx(t, sig2, withPublishEpoch(3)) - atx4 := newAtx(t, sig2, withPublishEpoch(4), withSequence(atx3.Sequence+1)) + atx1, _ := newAtx(t, sig1, withPublishEpoch(1)) + atx2, _ := newAtx(t, sig1, withPublishEpoch(2), withSequence(atx1.Sequence+1)) + atx3, _ := newAtx(t, sig2, withPublishEpoch(3)) + atx4, _ := newAtx(t, sig2, withPublishEpoch(4), withSequence(atx3.Sequence+1)) for _, atx := range []*types.ActivationTx{atx1, atx2, atx3, atx4} { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) } // Act & Assert - id1, err := atxs.GetFirstIDByNodeID(db, sig1.NodeID()) require.NoError(t, err) require.EqualValues(t, atx1.ID(), id1) @@ -164,15 +150,15 @@ func TestLatestN(t *testing.T) { sig3, err := signing.NewEdSigner() require.NoError(t, err) - atx1 := newAtx(t, sig1, withPublishEpoch(1), withSequence(0)) - atx2 := newAtx(t, sig1, withPublishEpoch(2), withSequence(1)) - atx3 := newAtx(t, sig2, withPublishEpoch(3), withSequence(1)) - atx4 := newAtx(t, sig2, withPublishEpoch(4), withSequence(2)) - atx5 := newAtx(t, sig2, withPublishEpoch(5), withSequence(3)) - atx6 := newAtx(t, sig3, withPublishEpoch(1), withSequence(0)) + atx1, _ := newAtx(t, sig1, withPublishEpoch(1), withSequence(0)) + atx2, _ := newAtx(t, sig1, withPublishEpoch(2), withSequence(1)) + atx3, _ := newAtx(t, sig2, withPublishEpoch(3), withSequence(1)) + atx4, _ := newAtx(t, sig2, withPublishEpoch(4), withSequence(2)) + atx5, _ := newAtx(t, sig2, withPublishEpoch(5), withSequence(3)) + atx6, _ := newAtx(t, sig3, withPublishEpoch(1), withSequence(0)) for _, atx := range []*types.ActivationTx{atx1, atx2, atx3, atx4, atx5, atx6} { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) require.NoError(t, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits)) } @@ -254,11 +240,11 @@ func TestGetByEpochAndNodeID(t *testing.T) { sig2, err := signing.NewEdSigner() require.NoError(t, err) - atx1 := newAtx(t, sig1, withPublishEpoch(1)) - atx2 := newAtx(t, sig2, withPublishEpoch(2)) + atx1, _ := newAtx(t, sig1, withPublishEpoch(1)) + atx2, _ := newAtx(t, sig2, withPublishEpoch(2)) for _, atx := range []*types.ActivationTx{atx1, atx2} { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) } // Act & Assert @@ -289,14 +275,13 @@ func TestGetLastIDByNodeID(t *testing.T) { require.NoError(t, err) // Arrange - - atx1 := newAtx(t, sig1, withPublishEpoch(1)) - atx2 := newAtx(t, sig1, withPublishEpoch(2), withSequence(atx1.Sequence+1)) - atx3 := newAtx(t, sig2, withPublishEpoch(3)) - atx4 := newAtx(t, sig2, withPublishEpoch(4), withSequence(atx3.Sequence+1)) + atx1, _ := newAtx(t, sig1, withPublishEpoch(1)) + atx2, _ := newAtx(t, sig1, withPublishEpoch(2), withSequence(atx1.Sequence+1)) + atx3, _ := newAtx(t, sig2, withPublishEpoch(3)) + atx4, _ := newAtx(t, sig2, withPublishEpoch(4), withSequence(atx3.Sequence+1)) for _, atx := range []*types.ActivationTx{atx1, atx2, atx3, atx4} { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) } // Act & Assert @@ -325,13 +310,13 @@ func TestGetIDByEpochAndNodeID(t *testing.T) { e2 := types.EpochID(2) e3 := types.EpochID(3) - atx1 := newAtx(t, sig1, withPublishEpoch(e1)) - atx2 := newAtx(t, sig1, withPublishEpoch(e2)) - atx3 := newAtx(t, sig2, withPublishEpoch(e2)) - atx4 := newAtx(t, sig2, withPublishEpoch(e3)) + atx1, _ := newAtx(t, sig1, withPublishEpoch(e1)) + atx2, _ := newAtx(t, sig1, withPublishEpoch(e2)) + atx3, _ := newAtx(t, sig2, withPublishEpoch(e2)) + atx4, _ := newAtx(t, sig2, withPublishEpoch(e3)) for _, atx := range []*types.ActivationTx{atx1, atx2, atx3, atx4} { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) } l1n1, err := atxs.GetIDByEpochAndNodeID(db, e1, sig1.NodeID()) @@ -370,13 +355,13 @@ func TestGetIDsByEpoch(t *testing.T) { e2 := types.EpochID(2) e3 := types.EpochID(3) - atx1 := newAtx(t, sig1, withPublishEpoch(e1)) - atx2 := newAtx(t, sig1, withPublishEpoch(e2)) - atx3 := newAtx(t, sig2, withPublishEpoch(e2)) - atx4 := newAtx(t, sig2, withPublishEpoch(e3)) + atx1, _ := newAtx(t, sig1, withPublishEpoch(e1)) + atx2, _ := newAtx(t, sig1, withPublishEpoch(e2)) + atx3, _ := newAtx(t, sig2, withPublishEpoch(e2)) + atx4, _ := newAtx(t, sig2, withPublishEpoch(e3)) for _, atx := range []*types.ActivationTx{atx1, atx2, atx3, atx4} { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) } ids1, err := atxs.GetIDsByEpoch(ctx, db, e1) @@ -406,15 +391,15 @@ func TestGetIDsByEpochCached(t *testing.T) { e2 := types.EpochID(2) e3 := types.EpochID(3) - atx1 := newAtx(t, sig1, withPublishEpoch(e1)) - atx2 := newAtx(t, sig1, withPublishEpoch(e2)) - atx3 := newAtx(t, sig2, withPublishEpoch(e2)) - atx4 := newAtx(t, sig2, withPublishEpoch(e3)) - atx5 := newAtx(t, sig2, withPublishEpoch(e3)) - atx6 := newAtx(t, sig2, withPublishEpoch(e3)) + atx1, _ := newAtx(t, sig1, withPublishEpoch(e1)) + atx2, _ := newAtx(t, sig1, withPublishEpoch(e2)) + atx3, _ := newAtx(t, sig2, withPublishEpoch(e2)) + atx4, _ := newAtx(t, sig2, withPublishEpoch(e3)) + atx5, _ := newAtx(t, sig2, withPublishEpoch(e3)) + atx6, _ := newAtx(t, sig2, withPublishEpoch(e3)) for _, atx := range []*types.ActivationTx{atx1, atx2, atx3, atx4} { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) atxs.AtxAdded(db, atx) } @@ -444,7 +429,7 @@ func TestGetIDsByEpochCached(t *testing.T) { } require.NoError(t, db.WithTx(context.Background(), func(tx *sql.Tx) error { - atxs.Add(tx, atx5) + atxs.Add(tx, atx5, types.AtxBlob{}) return nil })) atxs.AtxAdded(db, atx5) @@ -456,7 +441,7 @@ func TestGetIDsByEpochCached(t *testing.T) { require.Equal(t, 13, db.QueryCount()) // not incremented after Add require.Error(t, db.WithTx(context.Background(), func(tx *sql.Tx) error { - atxs.Add(tx, atx6) + atxs.Add(tx, atx6, types.AtxBlob{}) return errors.New("fail") // rollback })) @@ -475,8 +460,8 @@ func Test_IterateAtxsWithMalfeasance(t *testing.T) { for i := uint32(0); i < 20; i++ { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, sig, withPublishEpoch(types.EpochID(i/4))) - require.NoError(t, atxs.Add(db, atx)) + atx, blob := newAtx(t, sig, withPublishEpoch(types.EpochID(i/4))) + require.NoError(t, atxs.Add(db, atx, blob)) malicious := (i % 2) == 0 m[atx.ID()] = malicious if malicious { @@ -505,8 +490,8 @@ func Test_IterateAtxIdsWithMalfeasance(t *testing.T) { for i := uint32(0); i < 20; i++ { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, sig, withPublishEpoch(types.EpochID(i/4))) - require.NoError(t, atxs.Add(db, atx)) + atx, blob := newAtx(t, sig, withPublishEpoch(types.EpochID(i/4))) + require.NoError(t, atxs.Add(db, atx, blob)) malicious := (i % 2) == 0 m[atx.ID()] = malicious if malicious { @@ -534,11 +519,11 @@ func TestVRFNonce(t *testing.T) { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx1 := newAtx(t, sig, withPublishEpoch(20), withNonce(333)) - require.NoError(t, atxs.Add(db, atx1)) + atx1, blob := newAtx(t, sig, withPublishEpoch(20), withNonce(333)) + require.NoError(t, atxs.Add(db, atx1, blob)) - atx2 := newAtx(t, sig, withPublishEpoch(50), withNonce(777), withPrevATXID(atx1.ID())) - require.NoError(t, atxs.Add(db, atx2)) + atx2, blob := newAtx(t, sig, withPublishEpoch(50), withNonce(777), withPrevATXID(atx1.ID())) + require.NoError(t, atxs.Add(db, atx2, blob)) // Act & Assert @@ -567,33 +552,30 @@ func TestLoadBlob(t *testing.T) { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx1 := newAtx(t, sig, withPublishEpoch(1)) - atx1.AtxBlob.Blob = []byte("blob1") - atx1.AtxBlob.Version = types.AtxV1 - - require.NoError(t, atxs.Add(db, atx1)) + atx1, blob := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx1, blob)) var blob1 sql.Blob version, err := atxs.LoadBlob(ctx, db, atx1.ID().Bytes(), &blob1) require.NoError(t, err) require.Equal(t, types.AtxV1, version) - require.Equal(t, atx1.AtxBlob.Blob, blob1.Bytes) + require.Equal(t, blob.Blob, blob1.Bytes) blobSizes, err := atxs.GetBlobSizes(db, [][]byte{atx1.ID().Bytes()}) require.NoError(t, err) require.Equal(t, []int{len(blob1.Bytes)}, blobSizes) var blob2 sql.Blob - atx2 := newAtx(t, sig) - atx2.AtxBlob.Blob = []byte("blob2 of different size") - atx2.AtxBlob.Version = types.AtxV2 + atx2, blob := newAtx(t, sig) + blob.Blob = []byte("blob2 of different size") + blob.Version = types.AtxV2 - require.NoError(t, atxs.Add(db, atx2)) + require.NoError(t, atxs.Add(db, atx2, blob)) version, err = atxs.LoadBlob(ctx, db, atx2.ID().Bytes(), &blob2) require.NoError(t, err) - require.Equal(t, types.AtxV2, version) - require.Equal(t, atx2.AtxBlob.Blob, blob2.Bytes) + require.Equal(t, blob.Version, version) + require.Equal(t, blob.Blob, blob2.Bytes) blobSizes, err = atxs.GetBlobSizes(db, [][]byte{ atx1.ID().Bytes(), @@ -621,17 +603,16 @@ func TestLoadBlob_DefaultsToV1(t *testing.T) { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, sig) - atx.AtxBlob.Blob = []byte("blob1") - atx.AtxBlob.Version = 0 + atx, blob := newAtx(t, sig) + blob.Version = 0 - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, blob)) - var blob sql.Blob - version, err := atxs.LoadBlob(context.Background(), db, atx.ID().Bytes(), &blob) + var b sql.Blob + version, err := atxs.LoadBlob(context.Background(), db, atx.ID().Bytes(), &b) require.NoError(t, err) require.Equal(t, types.AtxV1, version) - require.Equal(t, atx.AtxBlob.Blob, blob.Bytes) + require.Equal(t, blob.Blob, b.Bytes) } func TestGetBlobCached(t *testing.T) { @@ -640,16 +621,16 @@ func TestGetBlobCached(t *testing.T) { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, sig, withPublishEpoch(1)) + atx, blob := newAtx(t, sig, withPublishEpoch(1)) - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, blob)) require.Equal(t, 2, db.QueryCount()) // insert atx + blob for i := 0; i < 3; i++ { var b sql.Blob _, err := atxs.LoadBlob(ctx, db, atx.ID().Bytes(), &b) require.NoError(t, err) - require.Equal(t, atx.Blob, b.Bytes) + require.Equal(t, blob.Blob, b.Bytes) require.Equal(t, 3, db.QueryCount()) } } @@ -659,29 +640,31 @@ func TestGetBlobCached(t *testing.T) { func TestGetBlobCached_CacheEntriesAreDistinct(t *testing.T) { db := sql.InMemory(sql.WithQueryCache(true)) - atx := types.ActivationTx{AtxBlob: types.AtxBlob{Blob: []byte("original blob")}} + atx := types.ActivationTx{} atx.SetID(types.RandomATXID()) - require.NoError(t, atxs.Add(db, &atx)) + blob := types.AtxBlob{Blob: []byte("original blob")} + require.NoError(t, atxs.Add(db, &atx, blob)) require.Equal(t, 2, db.QueryCount()) // insert atx + blob - blob := &sql.Blob{} - _, err := atxs.LoadBlob(context.Background(), db, atx.ID().Bytes(), blob) + b := &sql.Blob{} + _, err := atxs.LoadBlob(context.Background(), db, atx.ID().Bytes(), b) require.NoError(t, err) - require.Equal(t, atx.AtxBlob.Blob, blob.Bytes) + require.Equal(t, blob.Blob, b.Bytes) - atx2 := types.ActivationTx{AtxBlob: types.AtxBlob{Blob: []byte("other blob")}} + atx2 := types.ActivationTx{} atx2.SetID(types.RandomATXID()) - require.Less(t, len(atx2.AtxBlob.Blob), len(atx.AtxBlob.Blob)) - require.NoError(t, atxs.Add(db, &atx2)) + blob2 := types.AtxBlob{Blob: []byte("other blob")} + require.Less(t, len(blob2.Blob), len(blob.Blob)) + require.NoError(t, atxs.Add(db, &atx2, blob2)) // Loading atx2 doesn't overwrite the cached blob for atx1 - _, err = atxs.LoadBlob(context.Background(), db, atx2.ID().Bytes(), blob) + _, err = atxs.LoadBlob(context.Background(), db, atx2.ID().Bytes(), b) require.NoError(t, err) - require.Equal(t, atx2.AtxBlob.Blob, blob.Bytes) + require.Equal(t, blob2.Blob, b.Bytes) - _, err = atxs.LoadBlob(context.Background(), db, atx.ID().Bytes(), blob) + _, err = atxs.LoadBlob(context.Background(), db, atx.ID().Bytes(), b) require.NoError(t, err) - require.Equal(t, atx.AtxBlob.Blob, blob.Bytes) + require.Equal(t, blob.Blob, b.Bytes) } // Test that the cached blob is not shared with the caller @@ -690,18 +673,18 @@ func TestGetBlobCached_OverwriteSafety(t *testing.T) { db := sql.InMemory(sql.WithQueryCache(true)) atx := types.ActivationTx{} atx.SetID(types.RandomATXID()) - atx.AtxBlob.Blob = []byte("original blob") - require.NoError(t, atxs.Add(db, &atx)) + blob := types.AtxBlob{Blob: []byte("original blob")} + require.NoError(t, atxs.Add(db, &atx, blob)) require.Equal(t, 2, db.QueryCount()) // insert atx + blob var b sql.Blob // we will reuse the blob between queries _, err := atxs.LoadBlob(context.Background(), db, atx.ID().Bytes(), &b) require.NoError(t, err) - require.Equal(t, atx.AtxBlob.Blob, b.Bytes) + require.Equal(t, blob.Blob, b.Bytes) b.Bytes[0] = 'X' // modify the blob _, err = atxs.LoadBlob(context.Background(), db, atx.ID().Bytes(), &b) require.NoError(t, err) - require.Equal(t, atx.AtxBlob.Blob, b.Bytes) + require.Equal(t, blob.Blob, b.Bytes) } func TestCachedBlobEviction(t *testing.T) { @@ -718,13 +701,13 @@ func TestCachedBlobEviction(t *testing.T) { blobs := make([][]byte, 11) var b sql.Blob for n := range addedATXs { - atx := newAtx(t, sig, withPublishEpoch(1)) - require.NoError(t, atxs.Add(db, atx)) + atx, blob := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx, blob)) addedATXs[n] = atx - blobs[n] = atx.Blob + blobs[n] = blob.Blob _, err := atxs.LoadBlob(ctx, db, atx.ID().Bytes(), &b) require.NoError(t, err) - require.Equal(t, atx.Blob, b.Bytes) + require.Equal(t, blob.Blob, b.Bytes) } // insert atx + insert blob + load blob each time @@ -751,7 +734,7 @@ func TestCheckpointATX(t *testing.T) { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, sig, withPublishEpoch(3), withSequence(4)) + atx, _ := newAtx(t, sig, withPublishEpoch(3), withSequence(4)) catx := &atxs.CheckpointAtx{ ID: atx.ID(), Epoch: atx.PublishEpoch, @@ -801,14 +784,13 @@ func TestAdd(t *testing.T) { sig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, sig, withPublishEpoch(1)) + atx, blob := newAtx(t, sig, withPublishEpoch(1)) - require.NoError(t, atxs.Add(db, atx)) - require.ErrorIs(t, atxs.Add(db, atx), sql.ErrObjectExists) + require.NoError(t, atxs.Add(db, atx, blob)) + require.ErrorIs(t, atxs.Add(db, atx, blob), sql.ErrObjectExists) got, err := atxs.Get(db, atx.ID()) require.NoError(t, err) - atx.AtxBlob = types.AtxBlob{} require.Equal(t, atx, got) } @@ -838,7 +820,13 @@ func withPrevATXID(id types.ATXID) createAtxOpt { } } -func newAtx(t testing.TB, signer *signing.EdSigner, opts ...createAtxOpt) *types.ActivationTx { +func withCoinbase(addr types.Address) createAtxOpt { + return func(atx *types.ActivationTx) { + atx.Coinbase = addr + } +} + +func newAtx(t testing.TB, signer *signing.EdSigner, opts ...createAtxOpt) (*types.ActivationTx, types.AtxBlob) { nonce := uint64(123) watx := &wire.ActivationTxV1{ InnerActivationTxV1: wire.InnerActivationTxV1{ @@ -855,7 +843,7 @@ func newAtx(t testing.TB, signer *signing.EdSigner, opts ...createAtxOpt) *types for _, opt := range opts { opt(atx) } - return atx + return atx, watx.Blob() } type header struct { @@ -881,7 +869,7 @@ func createAtx(tb testing.TB, db *sql.Database, hdr header) (types.ATXID, *signi full.SetReceived(time.Now()) full.SetID(types.RandomATXID()) - require.NoError(tb, atxs.Add(db, full)) + require.NoError(tb, atxs.Add(db, full, types.AtxBlob{})) if hdr.malicious { require.NoError(tb, identities.SetMalicious(db, sig.NodeID(), []byte("bad"), time.Now())) } @@ -1030,7 +1018,7 @@ func TestLatest(t *testing.T) { } full.SetReceived(time.Now()) full.SetID(types.ATXID{byte(i)}) - require.NoError(t, atxs.Add(db, full)) + require.NoError(t, atxs.Add(db, full, types.AtxBlob{})) } latest, err := atxs.LatestEpoch(db) require.NoError(t, err) @@ -1047,21 +1035,19 @@ func Test_PrevATXCollisions(t *testing.T) { // create two ATXs with the same PrevATXID prevATXID := types.RandomATXID() - atx1 := newAtx(t, sig, withPublishEpoch(1), withPrevATXID(prevATXID)) - atx2 := newAtx(t, sig, withPublishEpoch(2), withPrevATXID(prevATXID)) + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1), withPrevATXID(prevATXID)) + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2), withPrevATXID(prevATXID)) - require.NoError(t, atxs.Add(db, atx1)) - require.NoError(t, atxs.Add(db, atx2)) + require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.Add(db, atx2, blob2)) // verify that the ATXs were added got1, err := atxs.Get(db, atx1.ID()) require.NoError(t, err) - atx1.AtxBlob = types.AtxBlob{} require.Equal(t, atx1, got1) got2, err := atxs.Get(db, atx2.ID()) require.NoError(t, err) - atx2.AtxBlob = types.AtxBlob{} require.Equal(t, atx2, got2) // add 10 valid ATXs by 10 other smeshers @@ -1069,14 +1055,14 @@ func Test_PrevATXCollisions(t *testing.T) { otherSig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i))) - require.NoError(t, atxs.Add(db, atx)) + atx, blob := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i))) + require.NoError(t, atxs.Add(db, atx, blob)) - atx2 := newAtx(t, otherSig, + atx2, blob2 := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i+1)), withPrevATXID(atx.ID()), ) - require.NoError(t, atxs.Add(db, atx2)) + require.NoError(t, atxs.Add(db, atx2, blob2)) } // get the collisions @@ -1102,8 +1088,8 @@ func TestCoinbase(t *testing.T) { db := sql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) - atx := newAtx(t, sig, func(a *types.ActivationTx) { a.Coinbase = types.Address{1, 2, 3} }) - require.NoError(t, atxs.Add(db, atx)) + atx, blob := newAtx(t, sig, withCoinbase(types.Address{1, 2, 3})) + require.NoError(t, atxs.Add(db, atx, blob)) cb, err := atxs.Coinbase(db, sig.NodeID()) require.NoError(t, err) require.Equal(t, atx.Coinbase, cb) @@ -1113,10 +1099,10 @@ func TestCoinbase(t *testing.T) { db := sql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) - atx1 := newAtx(t, sig, withPublishEpoch(1), func(a *types.ActivationTx) { a.Coinbase = types.Address{1, 2, 3} }) - atx2 := newAtx(t, sig, withPublishEpoch(2), func(a *types.ActivationTx) { a.Coinbase = types.Address{4, 5, 6} }) - require.NoError(t, atxs.Add(db, atx1)) - require.NoError(t, atxs.Add(db, atx2)) + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1), withCoinbase(types.Address{1, 2, 3})) + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2), withCoinbase(types.Address{4, 5, 6})) + require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.Add(db, atx2, blob2)) cb, err := atxs.Coinbase(db, sig.NodeID()) require.NoError(t, err) require.Equal(t, atx2.Coinbase, cb) diff --git a/sql/ballots/ballots_test.go b/sql/ballots/ballots_test.go index 624d648548..901161d0f4 100644 --- a/sql/ballots/ballots_test.go +++ b/sql/ballots/ballots_test.go @@ -163,7 +163,7 @@ func TestFirstInEpoch(t *testing.T) { sig, err := signing.NewEdSigner() require.NoError(t, err) atx := newAtx(sig, lid) - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) got, err := FirstInEpoch(db, atx.ID(), 2) require.ErrorIs(t, err, sql.ErrNotFound) diff --git a/syncer/atxsync/atxsync_test.go b/syncer/atxsync/atxsync_test.go index 33f649c727..632721910e 100644 --- a/syncer/atxsync/atxsync_test.go +++ b/syncer/atxsync/atxsync_test.go @@ -101,16 +101,15 @@ func TestDownload(t *testing.T) { ctrl := gomock.NewController(t) fetcher := mocks.NewMockAtxFetcher(ctrl) for _, atx := range tc.existing { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) } for i := range tc.fetched { req := tc.fetched[i] fetcher.EXPECT(). GetAtxs(tc.ctx, req.request, gomock.Any()). - Times(1). DoAndReturn(func(_ context.Context, _ []types.ATXID, _ ...system.GetAtxOpt) error { for _, atx := range req.result { - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) } return req.error }) diff --git a/syncer/atxsync/syncer_test.go b/syncer/atxsync/syncer_test.go index 79165f950d..8d9e3dc99b 100644 --- a/syncer/atxsync/syncer_test.go +++ b/syncer/atxsync/syncer_test.go @@ -90,7 +90,7 @@ func TestSyncer(t *testing.T) { GetAtxs(gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, ids []types.ATXID, _ ...system.GetAtxOpt) error { for _, id := range ids { - require.NoError(t, atxs.Add(tester.db, atx(id))) + require.NoError(t, atxs.Add(tester.db, atx(id), types.AtxBlob{})) } return nil }).AnyTimes() @@ -153,7 +153,7 @@ func TestSyncer(t *testing.T) { GetAtxs(gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, ids []types.ATXID, _ ...system.GetAtxOpt) error { for _, id := range ids { - require.NoError(t, atxs.Add(tester.db, atx(id))) + require.NoError(t, atxs.Add(tester.db, atx(id), types.AtxBlob{})) } return nil }).AnyTimes() @@ -185,7 +185,7 @@ func TestSyncer(t *testing.T) { for _, id := range ids { for _, good := range good.AtxIDs { if good == id { - require.NoError(t, atxs.Add(tester.db, atx(id))) + require.NoError(t, atxs.Add(tester.db, atx(id), types.AtxBlob{})) } } for _, bad := range bad.AtxIDs { diff --git a/tortoise/model/core.go b/tortoise/model/core.go index 04381a2aa1..88aa0cfa9d 100644 --- a/tortoise/model/core.go +++ b/tortoise/model/core.go @@ -174,7 +174,7 @@ func (c *core) OnMessage(m Messenger, event Message) { case MessageAtx: ev.Atx.BaseTickHeight = 1 ev.Atx.TickCount = 2 - atxs.Add(c.cdb, ev.Atx) + atxs.Add(c.cdb, ev.Atx, types.AtxBlob{}) malicious, err := c.cdb.IsMalicious(ev.Atx.SmesherID) if err != nil { c.logger.Fatal("failed is malicious lookup", zap.Error(err)) diff --git a/tortoise/sim/state.go b/tortoise/sim/state.go index 266997fec9..2a660e2cdc 100644 --- a/tortoise/sim/state.go +++ b/tortoise/sim/state.go @@ -44,7 +44,7 @@ func (s *State) OnBeacon(eid types.EpochID, beacon types.Beacon) { func (s *State) OnActivationTx(atx *types.ActivationTx) { // TODO: consider using actual values for malicious if needed s.Atxdata.AddFromAtx(atx, false) - if err := atxs.Add(s.DB, atx); err != nil { + if err := atxs.Add(s.DB, atx, types.AtxBlob{}); err != nil { s.logger.Panic("failed to add atx", zap.Error(err)) } } diff --git a/tortoise/threshold_test.go b/tortoise/threshold_test.go index 3b414ef06d..7fd299fa78 100644 --- a/tortoise/threshold_test.go +++ b/tortoise/threshold_test.go @@ -174,7 +174,7 @@ func TestReferenceHeight(t *testing.T) { } atx.SetID(types.ATXID{byte(i + 1)}) atx.SetReceived(time.Now()) - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) } _, height, err := extractAtxsData(db, types.EpochID(tc.epoch)) require.NoError(t, err) diff --git a/tortoise/tortoise_test.go b/tortoise/tortoise_test.go index 9d1ec119c0..9318731824 100644 --- a/tortoise/tortoise_test.go +++ b/tortoise/tortoise_test.go @@ -479,7 +479,7 @@ func TestComputeExpectedWeight(t *testing.T) { } atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - require.NoError(t, atxs.Add(db, atx)) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) } for lid := tc.target.Add(1); !lid.After(tc.last); lid = lid.Add(1) { weight, _, err := extractAtxsData(db, lid.GetEpoch())