diff --git a/activation/e2e/checkpoint_merged_test.go b/activation/e2e/checkpoint_merged_test.go index 6f0675fc1c..f0425911ee 100644 --- a/activation/e2e/checkpoint_merged_test.go +++ b/activation/e2e/checkpoint_merged_test.go @@ -273,6 +273,12 @@ func Test_CheckpointAfterMerge(t *testing.T) { require.Equal(t, i, marriage.Index) } + checkpointedMerged, err := atxs.Get(newDB, mergedATX.ID()) + require.NoError(t, err) + require.True(t, checkpointedMerged.Golden()) + require.NotNil(t, checkpointedMerged.MarriageATX) + require.Equal(t, marriageATX.ID(), *checkpointedMerged.MarriageATX) + // 4. Spawn new ATX handler and builder using the new DB poetDb = activation.NewPoetDb(newDB, logger.Named("poetDb")) cdb = datastore.NewCachedDB(newDB, logger) diff --git a/activation/handler_v2.go b/activation/handler_v2.go index fecf03ef83..e013b9d45c 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -117,6 +117,7 @@ func (h *HandlerV2) processATX( atx := &types.ActivationTx{ PublishEpoch: watx.PublishEpoch, + MarriageATX: watx.MarriageATX, Coinbase: watx.Coinbase, BaseTickHeight: baseTickHeight, NumUnits: parts.effectiveUnits, @@ -658,6 +659,7 @@ func (h *HandlerV2) checkMalicious( tx *sql.Tx, watx *wire.ActivationTxV2, marrying []marriage, + ids []types.NodeID, ) error { malicious, err := identities.IsMalicious(tx, watx.SmesherID) if err != nil { @@ -675,6 +677,22 @@ func (h *HandlerV2) checkMalicious( return nil } + malicious, err = h.checkDoublePost(ctx, tx, watx, ids) + if err != nil { + return fmt.Errorf("checking double post: %w", err) + } + if malicious { + return nil + } + + malicious, err = h.checkDoubleMerge(ctx, tx, watx) + if err != nil { + return fmt.Errorf("checking double merge: %w", err) + } + if malicious { + return nil + } + // TODO(mafa): contextual validation: // 1. check double-publish = ID contributed post to two ATXs in the same epoch // 2. check previous ATX @@ -705,6 +723,66 @@ func (h *HandlerV2) checkDoubleMarry( return false, nil } +func (h *HandlerV2) checkDoublePost( + ctx context.Context, + tx *sql.Tx, + atx *wire.ActivationTxV2, + ids []types.NodeID, +) (bool, error) { + for _, id := range ids { + atxids, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch) + switch { + case errors.Is(err, sql.ErrNotFound): + continue + case err != nil: + return false, fmt.Errorf("searching for double publish: %w", err) + } + otherAtxId := slices.IndexFunc(atxids, func(other types.ATXID) bool { return other != atx.ID() }) + otherAtx := atxids[otherAtxId] + h.logger.Debug( + "found ID that has already contributed its PoST in this epoch", + zap.Stringer("node_id", id), + zap.Stringer("atx_id", atx.ID()), + zap.Stringer("other_atx_id", otherAtx), + zap.Uint32("epoch", atx.PublishEpoch.Uint32()), + ) + // TODO(mafa): finish proof + proof := &wire.ATXProof{ + ProofType: wire.DoublePublish, + } + return true, h.malPublisher.Publish(ctx, id, proof) + } + return false, nil +} + +func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2) (bool, error) { + if watx.MarriageATX == nil { + return false, nil + } + ids, err := atxs.MergeConflict(tx, *watx.MarriageATX, watx.PublishEpoch) + switch { + case errors.Is(err, sql.ErrNotFound): + return false, nil + case err != nil: + return false, fmt.Errorf("searching for ATXs with the same marriage ATX: %w", err) + } + otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != watx.ID() }) + other := ids[otherIndex] + + h.logger.Debug("second merged ATX for single marriage - creating malfeasance proof", + zap.Stringer("marriage_atx", *watx.MarriageATX), + zap.Stringer("atx", watx.ID()), + zap.Stringer("other_atx", other), + zap.Stringer("smesher_id", watx.SmesherID), + ) + + // TODO(mafa): finish proof + proof := &wire.ATXProof{ + ProofType: wire.DoubleMerge, + } + return true, h.malPublisher.Publish(ctx, watx.SmesherID, proof) +} + // Store an ATX in the DB. func (h *HandlerV2) storeAtx( ctx context.Context, @@ -752,7 +830,7 @@ func (h *HandlerV2) storeAtx( // TODO(mafa): don't store own ATX if it would mark the node as malicious // this probably needs to be done by validating and storing own ATXs eagerly and skipping validation in // the gossip handler (not sync!) - err := h.checkMalicious(ctx, tx, watx, marrying) + err := h.checkMalicious(ctx, tx, watx, marrying, maps.Keys(units)) if err != nil { return fmt.Errorf("check malicious: %w", err) } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 501efbb47b..06f52f9b69 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -614,18 +614,16 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { func marryIDs( t testing.TB, atxHandler *v2TestHandler, - sig *signing.EdSigner, + signers []*signing.EdSigner, golden types.ATXID, - num int, ) (marriage *wire.ActivationTxV2, other []*wire.ActivationTxV2) { + sig := signers[0] mATX := newInitialATXv2(t, golden) mATX.Marriages = []wire.MarriageCertificate{{ Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), }} - for range num { - signer, err := signing.NewEdSigner() - require.NoError(t, err) + for _, signer := range signers[1:] { atx := atxHandler.createAndProcessInitial(t, signer) other = append(other, atx) mATX.Marriages = append(mATX.Marriages, wire.MarriageCertificate{ @@ -644,20 +642,27 @@ func marryIDs( func TestHandlerV2_ProcessMergedATX(t *testing.T) { t.Parallel() - golden := types.RandomATXID() - sig, err := signing.NewEdSigner() - require.NoError(t, err) + var ( + golden = types.RandomATXID() + signers []*signing.EdSigner + equivocationSet []types.NodeID + ) + for range 4 { + sig, err := signing.NewEdSigner() + require.NoError(t, err) + signers = append(signers, sig) + equivocationSet = append(equivocationSet, sig.NodeID()) + } + sig := signers[0] t.Run("happy case", func(t *testing.T) { atxHandler := newV2TestHandler(t, golden) // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 2) + mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) previousATXs := []types.ATXID{mATX.ID()} - equivocationSet := []types.NodeID{sig.NodeID()} for _, atx := range otherATXs { previousATXs = append(previousATXs, atx.ID()) - equivocationSet = append(equivocationSet, atx.SmesherID) } // Process a merged ATX @@ -694,12 +699,10 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler.tickSize = tickSize // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 4) + mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) previousATXs := []types.ATXID{mATX.ID()} - equivocationSet := []types.NodeID{sig.NodeID()} for _, atx := range otherATXs { previousATXs = append(previousATXs, atx.ID()) - equivocationSet = append(equivocationSet, atx.SmesherID) } // Process a merged ATX @@ -765,12 +768,10 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler := newV2TestHandler(t, golden) // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 2) + mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) previousATXs := []types.ATXID{} - equivocationSet := []types.NodeID{sig.NodeID()} for _, atx := range otherATXs { previousATXs = append(previousATXs, atx.ID()) - equivocationSet = append(equivocationSet, atx.SmesherID) } // Process a merged ATX @@ -802,12 +803,10 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler := newV2TestHandler(t, golden) // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 1) + mATX, otherATXs := marryIDs(t, atxHandler, signers[:2], golden) previousATXs := []types.ATXID{mATX.ID()} - equivocationSet := []types.NodeID{sig.NodeID()} for _, atx := range otherATXs { previousATXs = append(previousATXs, atx.ID()) - equivocationSet = append(equivocationSet, atx.SmesherID) } // Process a merged ATX @@ -836,12 +835,10 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler := newV2TestHandler(t, golden) // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 1) + mATX, otherATXs := marryIDs(t, atxHandler, signers[:2], golden) previousATXs := []types.ATXID{mATX.ID()} - equivocationSet := []types.NodeID{sig.NodeID()} for _, atx := range otherATXs { previousATXs = append(previousATXs, atx.ID()) - equivocationSet = append(equivocationSet, atx.SmesherID) } // Process a merged ATX @@ -868,11 +865,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { atxHandler := newV2TestHandler(t, golden) // Marry IDs - mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 1) - equivocationSet := []types.NodeID{sig.NodeID()} - for _, atx := range otherATXs { - equivocationSet = append(equivocationSet, atx.SmesherID) - } + mATX, _ := marryIDs(t, atxHandler, signers, golden) prev := atxs.CheckpointAtx{ Epoch: mATX.PublishEpoch + 1, @@ -932,6 +925,97 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { err = atxHandler.processATX(context.Background(), "", merged, time.Now()) require.ErrorIs(t, err, pubsub.ErrValidationReject) }) + t.Run("publishing two merged ATXs from one marriage set is malfeasance", func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + // Marry 4 IDs + mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) + previousATXs := []types.ATXID{mATX.ID()} + for _, atx := range otherATXs { + previousATXs = append(previousATXs, atx.ID()) + } + + // Process a merged ATX for 2 IDs + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + merged.NiPosts[0].Posts = []wire.SubPostV2{} + for i := range equivocationSet[:2] { + post := wire.SubPostV2{ + MarriageIndex: uint32(i), + PrevATXIndex: uint32(i), + NumUnits: 4, + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + } + + mATXID := mATX.ID() + + merged.MarriageATX = &mATXID + merged.PreviousATXs = []types.ATXID{mATX.ID(), otherATXs[0].ID()} + merged.Sign(sig) + + atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) + err := atxHandler.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + + // Process a second merged ATX for the same equivocation set, but different IDs + merged = newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + merged.NiPosts[0].Posts = []wire.SubPostV2{} + for i := range equivocationSet[:2] { + post := wire.SubPostV2{ + MarriageIndex: uint32(i + 2), + PrevATXIndex: uint32(i), + NumUnits: 4, + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + } + + mATXID = mATX.ID() + merged.MarriageATX = &mATXID + merged.PreviousATXs = []types.ATXID{otherATXs[1].ID(), otherATXs[2].ID()} + merged.Sign(signers[2]) + + atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) + atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), merged.SmesherID, gomock.Any()) + err = atxHandler.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + }) + t.Run("publishing two merged ATXs (one checkpointed)", func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) + mATXID := mATX.ID() + + // Insert checkpointed merged ATX + checkpointedATX := &atxs.CheckpointAtx{ + Epoch: mATX.PublishEpoch + 2, + ID: types.RandomATXID(), + SmesherID: signers[0].NodeID(), + MarriageATX: &mATXID, + } + require.NoError(t, atxs.AddCheckpointed(atxHandler.cdb, checkpointedATX)) + + // create and process another merged ATX + merged := newSoloATXv2(t, checkpointedATX.Epoch, mATX.ID(), golden) + merged.NiPosts[0].Posts = []wire.SubPostV2{} + for i := range equivocationSet[2:] { + post := wire.SubPostV2{ + MarriageIndex: uint32(i + 2), + PrevATXIndex: uint32(i), + NumUnits: 4, + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + } + + merged.MarriageATX = &mATXID + merged.PreviousATXs = []types.ATXID{otherATXs[1].ID(), otherATXs[2].ID()} + merged.Sign(signers[2]) + atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) + // TODO: this could be syntactically validated as all nodes in the network + // should already have the checkpointed merged ATX. + atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), merged.SmesherID, gomock.Any()) + err := atxHandler.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + }) } func TestCollectDeps_AtxV2(t *testing.T) { @@ -1730,6 +1814,64 @@ func Test_MarryingMalicious(t *testing.T) { } } +func TestContextualValidation_DoublePost(t *testing.T) { + t.Parallel() + golden := types.RandomATXID() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + atxHandler := newV2TestHandler(t, golden) + + // marry + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + othersAtx := atxHandler.createAndProcessInitial(t, otherSig) + + mATX := newInitialATXv2(t, golden) + mATX.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: othersAtx.ID(), + Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + mATX.Sign(sig) + + atxHandler.expectInitialAtxV2(mATX) + err = atxHandler.processATX(context.Background(), "", mATX, time.Now()) + require.NoError(t, err) + + // publish merged + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + NumUnits: othersAtx.TotalNumUnits(), + PrevATXIndex: 1, + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + + mATXID := mATX.ID() + merged.MarriageATX = &mATXID + + merged.PreviousATXs = []types.ATXID{mATX.ID(), othersAtx.ID()} + merged.Sign(sig) + + atxHandler.expectMergedAtxV2(merged, []types.NodeID{sig.NodeID(), otherSig.NodeID()}, []uint64{poetLeaves}) + err = atxHandler.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + + // The otherSig tries to publish alone in the same epoch. + // This is malfeasance as it tries include his PoST twice. + doubled := newSoloATXv2(t, merged.PublishEpoch, othersAtx.ID(), othersAtx.ID()) + doubled.Sign(otherSig) + atxHandler.expectAtxV2(doubled) + atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), otherSig.NodeID(), gomock.Any()) + err = atxHandler.processATX(context.Background(), "", doubled, time.Now()) + require.NoError(t, err) +} + func Test_CalculatingUnits(t *testing.T) { t.Parallel() t.Run("units on 1 nipost must not overflow", func(t *testing.T) { diff --git a/activation/wire/malfeasance.go b/activation/wire/malfeasance.go index c00bbcd984..d8e60a4127 100644 --- a/activation/wire/malfeasance.go +++ b/activation/wire/malfeasance.go @@ -12,6 +12,7 @@ type ProofType byte const ( DoublePublish ProofType = iota + 1 DoubleMarry + DoubleMerge InvalidPost ) diff --git a/api/grpcserver/v2alpha1/network.go b/api/grpcserver/v2alpha1/network.go index 60d4ad29ce..f9454394cc 100644 --- a/api/grpcserver/v2alpha1/network.go +++ b/api/grpcserver/v2alpha1/network.go @@ -11,17 +11,19 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/config" ) const ( Network = "network_v2alpha1" ) -func NewNetworkService(genesisTime time.Time, genesisID types.Hash20, layerDuration time.Duration) *NetworkService { +func NewNetworkService(genesisTime time.Time, config *config.Config) *NetworkService { return &NetworkService{ genesisTime: genesisTime, - genesisID: genesisID, - layerDuration: layerDuration, + genesisID: config.Genesis.GenesisID(), + layerDuration: config.LayerDuration, + labelsPerUnit: config.POST.LabelsPerUnit, } } @@ -29,6 +31,7 @@ type NetworkService struct { genesisTime time.Time genesisID types.Hash20 layerDuration time.Duration + labelsPerUnit uint64 } func (s *NetworkService) RegisterService(server *grpc.Server) { @@ -54,5 +57,6 @@ func (s *NetworkService) Info(context.Context, Hrp: types.NetworkHRP(), EffectiveGenesisLayer: types.GetEffectiveGenesis().Uint32(), LayersPerEpoch: types.GetLayersPerEpoch(), + LabelsPerUnit: s.labelsPerUnit, }, nil } diff --git a/api/grpcserver/v2alpha1/network_test.go b/api/grpcserver/v2alpha1/network_test.go index 989ab92a1e..0e87e3a7d1 100644 --- a/api/grpcserver/v2alpha1/network_test.go +++ b/api/grpcserver/v2alpha1/network_test.go @@ -9,13 +9,15 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/config" ) func TestNetworkService_Info(t *testing.T) { ctx := context.Background() genesis := time.Unix(genTimeUnix, 0) + c := config.DefaultTestConfig() - svc := NewNetworkService(genesis, genesisID, layerDuration) + svc := NewNetworkService(genesis, &c) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -27,10 +29,11 @@ func TestNetworkService_Info(t *testing.T) { require.NoError(t, err) require.Equal(t, genesis.UTC(), info.GenesisTime.AsTime().UTC()) - require.Equal(t, layerDuration, info.LayerDuration.AsDuration()) - require.Equal(t, genesisID.Bytes(), info.GenesisId) + require.Equal(t, c.LayerDuration, info.LayerDuration.AsDuration()) + require.Equal(t, c.Genesis.GenesisID().Bytes(), info.GenesisId) require.Equal(t, types.NetworkHRP(), info.Hrp) require.Equal(t, types.GetEffectiveGenesis().Uint32(), info.EffectiveGenesisLayer) require.Equal(t, types.GetLayersPerEpoch(), info.LayersPerEpoch) + require.Equal(t, c.POST.LabelsPerUnit, info.LabelsPerUnit) }) } diff --git a/api/grpcserver/v2alpha1/v2alpha1_test.go b/api/grpcserver/v2alpha1/v2alpha1_test.go index 27d5d9f7ac..52ba08f556 100644 --- a/api/grpcserver/v2alpha1/v2alpha1_test.go +++ b/api/grpcserver/v2alpha1/v2alpha1_test.go @@ -11,7 +11,6 @@ import ( "google.golang.org/grpc/credentials/insecure" "github.com/spacemeshos/go-spacemesh/api/grpcserver" - "github.com/spacemeshos/go-spacemesh/common/types" ) const ( @@ -19,8 +18,6 @@ const ( layerDuration = 10 * time.Second ) -var genesisID = types.Hash20{} - func launchServer(tb testing.TB, services ...grpcserver.ServiceAPI) (grpcserver.Config, func()) { cfg := grpcserver.DefaultTestConfig() grpc, err := grpcserver.NewWithServices(cfg.PublicListener, zaptest.NewLogger(tb).Named("grpc"), cfg, services) diff --git a/bootstrap/updater.go b/bootstrap/updater.go index 04f53cbdd2..9641aefb96 100644 --- a/bootstrap/updater.go +++ b/bootstrap/updater.go @@ -29,6 +29,7 @@ import ( "github.com/santhosh-tekuri/jsonschema/v5" "github.com/spf13/afero" + "go.uber.org/zap" "golang.org/x/sync/errgroup" "github.com/spacemeshos/go-spacemesh/common/types" @@ -71,7 +72,7 @@ func DefaultConfig() Config { type Updater struct { cfg Config - logger log.Log + logger *zap.Logger clock layerClock fs afero.Fs client *http.Client @@ -92,7 +93,7 @@ func WithConfig(cfg Config) Opt { } } -func WithLogger(logger log.Log) Opt { +func WithLogger(logger *zap.Logger) Opt { return func(u *Updater) { u.logger = logger } @@ -113,7 +114,7 @@ func WithHttpClient(c *http.Client) Opt { func New(clock layerClock, opts ...Opt) *Updater { u := &Updater{ cfg: DefaultConfig(), - logger: log.NewNop(), + logger: zap.NewNop(), clock: clock, fs: afero.NewOsFs(), client: &http.Client{}, @@ -149,7 +150,7 @@ func (u *Updater) Load(ctx context.Context) error { if err = u.updateAndNotify(ctx, verified); err != nil { return err } - u.logger.With().Info("loaded bootstrap file", log.Inline(verified)) + u.logger.Info("loaded bootstrap file", zap.Inline(verified)) u.addUpdate(verified.Data.Epoch, verified.Persisted[len(verified.Persisted)-suffixLen:]) } return nil @@ -165,14 +166,14 @@ func (u *Updater) Start() error { if err := u.Load(ctx); err != nil { return err } - u.logger.With().Info("start listening to update", - log.String("source", u.cfg.URL), - log.Duration("interval", u.cfg.Interval), + u.logger.Info("start listening to update", + zap.String("source", u.cfg.URL), + zap.Duration("interval", u.cfg.Interval), ) for { if err := u.DoIt(ctx); err != nil { updateFailureCount.Add(1) - u.logger.With().Debug("failed to get bootstrap update", log.Err(err)) + u.logger.Debug("failed to get bootstrap update", zap.Error(err)) } select { case <-u.stop: @@ -233,10 +234,10 @@ func (u *Updater) DoIt(ctx context.Context) error { current := u.clock.CurrentLayer().GetEpoch() defer func() { if err := u.prune(current); err != nil { - u.logger.With().Error("failed to prune", - log.Context(ctx), - log.Uint32("current epoch", current.Uint32()), - log.Err(err), + u.logger.Error("failed to prune", + log.ZContext(ctx), + zap.Uint32("current epoch", current.Uint32()), + zap.Error(err), ) } }() @@ -291,7 +292,7 @@ func (u *Updater) checkEpochUpdate( return nil, false, fmt.Errorf("persist bootstrap %s: %w", filename, err) } verified.Persisted = filename - u.logger.WithContext(ctx).With().Info("new bootstrap file", log.Inline(verified)) + u.logger.Info("new bootstrap file", log.ZContext(ctx), zap.Inline(verified)) if err = u.updateAndNotify(ctx, verified); err != nil { return verified, false, err } diff --git a/bootstrap/updater_test.go b/bootstrap/updater_test.go index 2cc98e4c4d..4a3b712b36 100644 --- a/bootstrap/updater_test.go +++ b/bootstrap/updater_test.go @@ -15,11 +15,11 @@ import ( "github.com/spf13/afero" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "go.uber.org/zap/zaptest" "golang.org/x/sync/errgroup" "github.com/spacemeshos/go-spacemesh/bootstrap" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/log/logtest" ) const ( @@ -196,7 +196,7 @@ func TestLoad(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), ) ch, err := updater.Subscribe() @@ -248,7 +248,7 @@ func TestLoadedNotDownloadedAgain(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), ) ch, err := updater.Subscribe() @@ -276,7 +276,7 @@ func TestStartClose(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), ) ch, err := updater.Subscribe() @@ -322,7 +322,7 @@ func TestPrune(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), bootstrap.WithHttpClient(ts.Client()), ) @@ -384,7 +384,7 @@ func TestDoIt(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), bootstrap.WithHttpClient(ts.Client()), ) @@ -421,7 +421,7 @@ func TestEmptyResponse(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), bootstrap.WithHttpClient(ts.Client()), ) @@ -492,7 +492,7 @@ func TestGetInvalidUpdate(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), bootstrap.WithHttpClient(ts.Client()), ) @@ -526,7 +526,7 @@ func TestNoNewUpdate(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), bootstrap.WithHttpClient(ts.Client()), ) @@ -647,7 +647,7 @@ func TestRequiredEpochs(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), bootstrap.WithHttpClient(ts.Client()), ) @@ -676,7 +676,7 @@ func TestIntegration(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), ) ch, err := updater.Subscribe() @@ -709,7 +709,7 @@ func TestClose(t *testing.T) { updater := bootstrap.New( mc, bootstrap.WithConfig(cfg), - bootstrap.WithLogger(logtest.New(t)), + bootstrap.WithLogger(zaptest.NewLogger(t)), bootstrap.WithFilesystem(fs), ) ch, err := updater.Subscribe() diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 97a77d5468..31df1cdff1 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -356,6 +356,10 @@ func checkpointData(fs afero.Fs, file string, newGenesis types.LayerID) (*recove cAtx.ID = types.ATXID(types.BytesToHash(atx.ID)) cAtx.Epoch = types.EpochID(atx.Epoch) cAtx.CommitmentATX = types.ATXID(types.BytesToHash(atx.CommitmentAtx)) + if len(atx.MarriageAtx) == 32 { + marriageATXID := types.ATXID(atx.MarriageAtx) + cAtx.MarriageATX = &marriageATXID + } cAtx.SmesherID = types.BytesToNodeID(atx.PublicKey) cAtx.NumUnits = atx.NumUnits cAtx.VRFNonce = types.VRFPostIndex(atx.VrfNonce) diff --git a/checkpoint/recovery_test.go b/checkpoint/recovery_test.go index 5fe2533300..07669b47bf 100644 --- a/checkpoint/recovery_test.go +++ b/checkpoint/recovery_test.go @@ -908,7 +908,8 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { require.NoError(t, err) atxid, err := hex.DecodeString("98e47278c1f58acfd2b670a730f28898f74eb140482a07b91ff81f9ff0b7d9f4") require.NoError(t, err) - atx := newAtx(types.ATXID(atxid), types.EmptyATXID, nil, 3, 1, 0, nid) + atx := &types.ActivationTx{SmesherID: types.NodeID(nid)} + atx.SetID(types.ATXID(atxid)) cfg := &checkpoint.RecoverConfig{ GoldenAtx: goldenAtx, diff --git a/checkpoint/runner.go b/checkpoint/runner.go index 2b72e236b0..7039aa54c0 100644 --- a/checkpoint/runner.go +++ b/checkpoint/runner.go @@ -78,10 +78,15 @@ func checkpointDB( if mal, ok := malicious[catx.SmesherID]; ok && mal { continue } + var marriageAtx []byte + if catx.MarriageATX != nil { + marriageAtx = catx.MarriageATX.Bytes() + } checkpoint.Data.Atxs = append(checkpoint.Data.Atxs, types.AtxSnapshot{ ID: catx.ID.Bytes(), Epoch: catx.Epoch.Uint32(), CommitmentAtx: catx.CommitmentATX.Bytes(), + MarriageAtx: marriageAtx, VrfNonce: uint64(catx.VRFNonce), NumUnits: catx.NumUnits, BaseTickHeight: catx.BaseTickHeight, diff --git a/checkpoint/runner_test.go b/checkpoint/runner_test.go index f7009c24ec..472f62a8be 100644 --- a/checkpoint/runner_test.go +++ b/checkpoint/runner_test.go @@ -238,10 +238,15 @@ func newAtx( } func asAtxSnapshot(v *types.ActivationTx, cmt *types.ATXID) types.AtxSnapshot { + var marriageATX []byte + if v.MarriageATX != nil { + marriageATX = v.MarriageATX.Bytes() + } return types.AtxSnapshot{ ID: v.ID().Bytes(), Epoch: v.PublishEpoch.Uint32(), CommitmentAtx: cmt.Bytes(), + MarriageAtx: marriageATX, VrfNonce: uint64(v.VRFNonce), NumUnits: v.NumUnits, BaseTickHeight: v.BaseTickHeight, @@ -375,3 +380,35 @@ func TestRunner_Generate_Error(t *testing.T) { require.Error(t, err) }) } + +func TestRunner_Generate_PreservesMarriageATX(t *testing.T) { + t.Parallel() + db := sql.InMemory() + + require.NoError(t, accounts.Update(db, &types.Account{Address: types.Address{1, 1}})) + + atx := &types.ActivationTx{ + CommitmentATX: &types.ATXID{1, 2, 3, 4, 5}, + MarriageATX: &types.ATXID{6, 7, 8, 9}, + SmesherID: types.RandomNodeID(), + NumUnits: 4, + } + atx.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) + require.NoError(t, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits)) + + fs := afero.NewMemMapFs() + dir, err := afero.TempDir(fs, "", "Generate") + require.NoError(t, err) + + err = checkpoint.Generate(context.Background(), fs, db, dir, 5, 2) + require.NoError(t, err) + + file, err := fs.Open(checkpoint.SelfCheckpointFilename(dir, 5)) + require.NoError(t, err) + defer file.Close() + + var checkpoint types.Checkpoint + require.NoError(t, json.NewDecoder(file).Decode(&checkpoint)) + require.Equal(t, atx.MarriageATX.Bytes(), checkpoint.Data.Atxs[0].MarriageAtx) +} diff --git a/common/types/activation.go b/common/types/activation.go index 41112efc4c..4a1da34c15 100644 --- a/common/types/activation.go +++ b/common/types/activation.go @@ -178,7 +178,9 @@ type ActivationTx struct { PrevATXID ATXID // CommitmentATX is the ATX used in the commitment for initializing the PoST of the node. - CommitmentATX *ATXID + CommitmentATX *ATXID + // The marriage ATX, used in merged ATXs only. + MarriageATX *ATXID Coinbase Address NumUnits uint32 // the minimum number of space units in this and the previous ATX BaseTickHeight uint64 @@ -231,6 +233,9 @@ func (atx *ActivationTx) MarshalLogObject(encoder log.ObjectEncoder) error { if atx.CommitmentATX != nil { encoder.AddString("commitment_atx_id", atx.CommitmentATX.String()) } + if atx.MarriageATX != nil { + encoder.AddString("marriage_atx_id", atx.MarriageATX.String()) + } encoder.AddUint64("vrf_nonce", uint64(atx.VRFNonce)) encoder.AddString("coinbase", atx.Coinbase.String()) encoder.AddUint32("epoch", atx.PublishEpoch.Uint32()) diff --git a/common/types/checkpoint.go b/common/types/checkpoint.go index 7f04b35a87..81184e6b30 100644 --- a/common/types/checkpoint.go +++ b/common/types/checkpoint.go @@ -17,6 +17,7 @@ type AtxSnapshot struct { ID []byte `json:"id"` Epoch uint32 `json:"epoch"` CommitmentAtx []byte `json:"commitmentAtx"` + MarriageAtx []byte `json:"marriageAtx"` VrfNonce uint64 `json:"vrfNonce"` BaseTickHeight uint64 `json:"baseTickHeight"` TickCount uint64 `json:"tickCount"` diff --git a/go.mod b/go.mod index b8652db3cc..931e19278b 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/seehuhn/mt19937 v1.0.0 github.com/slok/go-http-metrics v0.12.0 - github.com/spacemeshos/api/release/go v1.50.0 + github.com/spacemeshos/api/release/go v1.51.0 github.com/spacemeshos/economics v0.1.3 github.com/spacemeshos/fixed v0.1.1 github.com/spacemeshos/go-scale v1.2.0 diff --git a/go.sum b/go.sum index 42c9cad502..be27867b00 100644 --- a/go.sum +++ b/go.sum @@ -602,8 +602,8 @@ github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= -github.com/spacemeshos/api/release/go v1.50.0 h1:M7Usg/LxymscwqYO7/Doyb+sU4lS1e+JIsSgqTDGk/0= -github.com/spacemeshos/api/release/go v1.50.0/go.mod h1:PvgDpjfwkZLVVNExYG7wDNzgMqT3p+ppfTU2UESSF9U= +github.com/spacemeshos/api/release/go v1.51.0 h1:MSKRIUiXBAoDrj2Lj24q9g52ZaSIC3I0UH/Y0Oaz95o= +github.com/spacemeshos/api/release/go v1.51.0/go.mod h1:Qr/pVPMmN5Q5qLHSXqVMDKDCu6LkHWzGPNflylE0u00= github.com/spacemeshos/economics v0.1.3 h1:ACkq3mTebIky4Zwbs9SeSSRZrUCjU/Zk0wq9Z0BTh2A= github.com/spacemeshos/economics v0.1.3/go.mod h1:FH7u0FzTIm6Kpk+X5HOZDvpkgNYBKclmH86rVwYaDAo= github.com/spacemeshos/fixed v0.1.1 h1:N1y4SUpq1EV+IdJrWJwUCt1oBFzeru/VKVcBsvPc2Fk= diff --git a/node/node.go b/node/node.go index 756fa4733f..2d9eae8d34 100644 --- a/node/node.go +++ b/node/node.go @@ -784,7 +784,7 @@ func (app *App) initServices(ctx context.Context) error { app.updater = bootstrap.New( app.clock, bootstrap.WithConfig(bscfg), - bootstrap.WithLogger(app.addLogger(BootstrapLogger, lg)), + bootstrap.WithLogger(app.addLogger(BootstrapLogger, lg).Zap()), ) if app.Config.Certificate.CommitteeSize == 0 { app.log.With().Warning("certificate committee size is not set, defaulting to hare committee size", @@ -1562,8 +1562,7 @@ func (app *App) grpcService(svc grpcserver.Service, lg log.Log) (grpcserver.Serv case v2alpha1.Network: service := v2alpha1.NewNetworkService( app.clock.GenesisTime(), - app.Config.Genesis.GenesisID(), - app.Config.LayerDuration) + app.Config) app.grpcServices[svc] = service return service, nil case v2alpha1.Node: diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 408e450c75..5e14cddde1 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -22,7 +22,8 @@ const ( // filters that refer to the id column. const fieldsQuery = `select atxs.id, atxs.nonce, atxs.base_tick_height, atxs.tick_count, atxs.pubkey, atxs.effective_num_units, -atxs.received, atxs.epoch, atxs.sequence, atxs.coinbase, atxs.validity, atxs.prev_id, atxs.commitment_atx, atxs.weight` +atxs.received, atxs.epoch, atxs.sequence, atxs.coinbase, atxs.validity, atxs.prev_id, atxs.commitment_atx, atxs.weight, +atxs.marriage_atx` const fullQuery = fieldsQuery + ` from atxs` @@ -62,6 +63,10 @@ func decoder(fn decoderCallback) sql.Decoder { stmt.ColumnBytes(12, a.CommitmentATX[:]) } a.Weight = uint64(stmt.ColumnInt64(13)) + if stmt.ColumnType(14) != sqlite.SQLITE_NULL { + a.MarriageATX = new(types.ATXID) + stmt.ColumnBytes(14, a.MarriageATX[:]) + } return fn(&a) } @@ -425,8 +430,6 @@ func Add(db sql.Executor, atx *types.ActivationTx, blob types.AtxBlob) error { stmt.BindInt64(3, int64(atx.NumUnits)) if atx.CommitmentATX != nil { stmt.BindBytes(4, atx.CommitmentATX.Bytes()) - } else { - stmt.BindNull(4) } stmt.BindInt64(5, int64(atx.VRFNonce)) stmt.BindBytes(6, atx.SmesherID.Bytes()) @@ -438,17 +441,18 @@ func Add(db sql.Executor, atx *types.ActivationTx, blob types.AtxBlob) error { stmt.BindInt64(12, int64(atx.Validity())) if atx.PrevATXID != types.EmptyATXID { stmt.BindBytes(13, atx.PrevATXID.Bytes()) - } else { - stmt.BindNull(13) } stmt.BindInt64(14, int64(atx.Weight)) + if atx.MarriageATX != nil { + stmt.BindBytes(15, atx.MarriageATX.Bytes()) + } } _, err := db.Exec(` insert into atxs (id, epoch, effective_num_units, commitment_atx, nonce, pubkey, received, base_tick_height, tick_count, sequence, coinbase, - validity, prev_id, weight) - values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)`, enc, nil) + validity, prev_id, weight, marriage_atx) + values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)`, enc, nil) if err != nil { return fmt.Errorf("insert ATX ID %v: %w", atx.ID(), err) } @@ -539,6 +543,7 @@ type CheckpointAtx struct { ID types.ATXID Epoch types.EpochID CommitmentATX types.ATXID + MarriageATX *types.ATXID VRFNonce types.VRFPostIndex BaseTickHeight uint64 TickCount uint64 @@ -571,16 +576,21 @@ func LatestN(db sql.Executor, n int) ([]CheckpointAtx, error) { catx.Sequence = uint64(stmt.ColumnInt64(6)) stmt.ColumnBytes(7, catx.Coinbase[:]) catx.VRFNonce = types.VRFPostIndex(stmt.ColumnInt64(8)) - catx.Units = make(map[types.NodeID]uint32) + if stmt.ColumnType(9) != sqlite.SQLITE_NULL { + catx.MarriageATX = new(types.ATXID) + stmt.ColumnBytes(9, catx.MarriageATX[:]) + } rst = append(rst, catx) return true } rows, err := db.Exec(` - select id, epoch, effective_num_units, base_tick_height, tick_count, pubkey, sequence, coinbase, nonce + select + id, epoch, effective_num_units, base_tick_height, tick_count, pubkey, sequence, coinbase, nonce, marriage_atx from ( select row_number() over (partition by pubkey order by epoch desc) RowNum, - id, epoch, effective_num_units, base_tick_height, tick_count, pubkey, sequence, coinbase, nonce + id, epoch, effective_num_units, base_tick_height, tick_count, pubkey, sequence, coinbase, nonce, + marriage_atx from atxs ) where RowNum <= ?1 order by pubkey;`, enc, dec) @@ -616,12 +626,15 @@ func AddCheckpointed(db sql.Executor, catx *CheckpointAtx) error { stmt.BindInt64(8, int64(catx.Sequence)) stmt.BindBytes(9, catx.SmesherID.Bytes()) stmt.BindBytes(10, catx.Coinbase.Bytes()) + if catx.MarriageATX != nil { + stmt.BindBytes(11, catx.MarriageATX.Bytes()) + } } _, err := db.Exec(` insert into atxs (id, epoch, effective_num_units, commitment_atx, nonce, - base_tick_height, tick_count, sequence, pubkey, coinbase, received) - values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, 0)`, enc, nil) + base_tick_height, tick_count, sequence, pubkey, coinbase, marriage_atx, received) + values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, 0)`, enc, nil) if err != nil { return fmt.Errorf("insert checkpoint ATX %v: %w", catx.ID, err) } @@ -803,7 +816,7 @@ func IterateAtxsWithMalfeasance( func(s *sql.Statement) { s.BindInt64(1, int64(publish)) }, func(s *sql.Statement) bool { return decoder(func(atx *types.ActivationTx) bool { - return fn(atx, s.ColumnInt(14) != 0) + return fn(atx, s.ColumnInt(15) != 0) })(s) }, ) @@ -891,6 +904,38 @@ func Units(db sql.Executor, atxID types.ATXID, nodeID types.NodeID) (uint32, err return units, err } +// FindDoublePublish finds 2 distinct ATXIDs that the given identity contributed PoST to in the given epoch. +// +// It is guaranteed to return 2 distinct ATXs when the error is nil. +// It works by finding an ATX in the given epoch that has a PoST contribution from the given identity. +// - `epoch` is looked up in the `atxs` table by matching atxid. +func FindDoublePublish(db sql.Executor, nodeID types.NodeID, epoch types.EpochID) ([]types.ATXID, error) { + var ids []types.ATXID + rows, err := db.Exec(` + SELECT p.atxid + FROM posts p + INNER JOIN atxs a ON p.atxid = a.id + WHERE p.pubkey = ?1 AND a.epoch = ?2;`, + func(stmt *sql.Statement) { + stmt.BindBytes(1, nodeID.Bytes()) + stmt.BindInt64(2, int64(epoch)) + }, + func(stmt *sql.Statement) bool { + var id types.ATXID + stmt.ColumnBytes(0, id[:]) + ids = append(ids, id) + return len(ids) < 2 + }, + ) + if err != nil { + return nil, err + } + if rows != 2 { + return nil, sql.ErrNotFound + } + return ids, nil +} + func AllUnits(db sql.Executor, id types.ATXID) (map[types.NodeID]uint32, error) { units := make(map[types.NodeID]uint32) rows, err := db.Exec( @@ -963,3 +1008,28 @@ func AtxWithPrevious(db sql.Executor, prev types.ATXID, id types.NodeID) (types. } return atxid, nil } + +// Find 2 distinct merged ATXs (having the same marriage ATX) in the same epoch. +func MergeConflict(db sql.Executor, marriage types.ATXID, publish types.EpochID) ([]types.ATXID, error) { + var ids []types.ATXID + rows, err := db.Exec(` + SELECT id FROM atxs WHERE marriage_atx = ?1 and epoch = ?2;`, + func(stmt *sql.Statement) { + stmt.BindBytes(1, marriage.Bytes()) + stmt.BindInt64(2, int64(publish)) + }, + func(stmt *sql.Statement) bool { + var id types.ATXID + stmt.ColumnBytes(0, id[:]) + ids = append(ids, id) + return len(ids) < 2 + }, + ) + if err != nil { + return nil, err + } + if rows != 2 { + return nil, sql.ErrNotFound + } + return ids, nil +} diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index d3ac0dd3de..1dd1914968 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1201,3 +1201,132 @@ func Test_AtxWithPrevious(t *testing.T) { require.Equal(t, atx2.ID(), id) }) } + +func Test_FindDoublePublish(t *testing.T) { + t.Parallel() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + t.Run("no atxs", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + _, err := atxs.FindDoublePublish(db, types.RandomNodeID(), 0) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + + t.Run("no double publish", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + + // one atx + atx0, blob := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx0, blob)) + require.NoError(t, atxs.SetUnits(db, atx0.ID(), atx0.SmesherID, 10)) + + _, err = atxs.FindDoublePublish(db, atx0.SmesherID, atx0.PublishEpoch) + require.ErrorIs(t, err, sql.ErrNotFound) + + // two atxs in different epochs + atx1, blob := newAtx(t, sig, withPublishEpoch(atx0.PublishEpoch+1)) + require.NoError(t, atxs.Add(db, atx1, blob)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), atx0.SmesherID, 10)) + + _, err = atxs.FindDoublePublish(db, atx0.SmesherID, atx0.PublishEpoch) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("double publish", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + + atx0, blob := newAtx(t, sig) + require.NoError(t, atxs.Add(db, atx0, blob)) + require.NoError(t, atxs.SetUnits(db, atx0.ID(), atx0.SmesherID, 10)) + + atx1, blob := newAtx(t, sig) + require.NoError(t, atxs.Add(db, atx1, blob)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), atx0.SmesherID, 10)) + + atxids, err := atxs.FindDoublePublish(db, atx0.SmesherID, atx0.PublishEpoch) + require.NoError(t, err) + require.ElementsMatch(t, []types.ATXID{atx0.ID(), atx1.ID()}, atxids) + + // filters by epoch + _, err = atxs.FindDoublePublish(db, atx0.SmesherID, atx0.PublishEpoch+1) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("double publish different smesher", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + + atx0Signer, err := signing.NewEdSigner() + require.NoError(t, err) + + atx0, blob := newAtx(t, atx0Signer) + require.NoError(t, atxs.Add(db, atx0, blob)) + require.NoError(t, atxs.SetUnits(db, atx0.ID(), atx0.SmesherID, 10)) + require.NoError(t, atxs.SetUnits(db, atx0.ID(), sig.NodeID(), 10)) + + atx1Signer, err := signing.NewEdSigner() + require.NoError(t, err) + + atx1, blob := newAtx(t, atx1Signer) + require.NoError(t, atxs.Add(db, atx1, blob)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), atx1.SmesherID, 10)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), sig.NodeID(), 10)) + + atxIDs, err := atxs.FindDoublePublish(db, sig.NodeID(), atx0.PublishEpoch) + require.NoError(t, err) + require.ElementsMatch(t, []types.ATXID{atx0.ID(), atx1.ID()}, atxIDs) + }) +} + +func Test_MergeConflict(t *testing.T) { + t.Parallel() + t.Run("no atxs", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + _, err := atxs.MergeConflict(db, types.RandomATXID(), 0) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("no conflict", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + marriage := types.RandomATXID() + + atx := types.ActivationTx{MarriageATX: &marriage} + atx.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(db, &atx, types.AtxBlob{})) + + _, err := atxs.MergeConflict(db, types.RandomATXID(), atx.PublishEpoch) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("finds conflict", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + marriage := types.RandomATXID() + + atx0 := types.ActivationTx{MarriageATX: &marriage} + atx0.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(db, &atx0, types.AtxBlob{})) + + atx1 := types.ActivationTx{MarriageATX: &marriage} + atx1.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(db, &atx1, types.AtxBlob{})) + + ids, err := atxs.MergeConflict(db, marriage, atx0.PublishEpoch) + require.NoError(t, err) + require.ElementsMatch(t, []types.ATXID{atx0.ID(), atx1.ID()}, ids) + + // filters by epoch + _, err = atxs.MergeConflict(db, types.RandomATXID(), 8) + require.ErrorIs(t, err, sql.ErrNotFound) + + // returns only 2 ATXs + atx2 := types.ActivationTx{MarriageATX: &marriage} + atx2.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(db, &atx2, types.AtxBlob{})) + + ids, err = atxs.MergeConflict(db, marriage, atx0.PublishEpoch) + require.NoError(t, err) + require.Len(t, ids, 2) + }) +} diff --git a/sql/migrations/state/0020_atx_merge.sql b/sql/migrations/state/0020_atx_merge.sql index 17bcda9c83..8dbff567c0 100644 --- a/sql/migrations/state/0020_atx_merge.sql +++ b/sql/migrations/state/0020_atx_merge.sql @@ -1,5 +1,6 @@ -- Changes required to handle merged ATXs +ALTER TABLE atxs ADD COLUMN marriage_atx CHAR(32); ALTER TABLE atxs ADD COLUMN weight INTEGER; UPDATE atxs SET weight = effective_num_units * tick_count;