diff --git a/CHANGELOG.md b/CHANGELOG.md index 68551a1759..35dc65c0d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ See [RELEASE](./RELEASE.md) for workflow instructions. * [#5930](https://github.com/spacemeshos/go-spacemesh/pull/5930) Check if identity for a given malfeasance proof exists when validating it. +* [#5923](https://github.com/spacemeshos/go-spacemesh/pull/5923) Fix high memory consumption and performance issues + in the proposal handler. + ## Release v1.5.2-hotfix1 This release includes our first CVE fix. A vulnerability was found in the way a node handles incoming ATXs. We urge all diff --git a/proposals/handler.go b/proposals/handler.go index 956751c1b7..ac8e741a4c 100644 --- a/proposals/handler.go +++ b/proposals/handler.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "sync" "time" lru "github.com/hashicorp/golang-lru/v2" @@ -48,16 +49,18 @@ type Handler struct { logger log.Log cfg Config - db *sql.Database - atxsdata *atxsdata.Data - activeSets *lru.Cache[types.Hash32, uint64] - edVerifier *signing.EdVerifier - publisher pubsub.Publisher - fetcher system.Fetcher - mesh meshProvider - validator eligibilityValidator - tortoise tortoiseProvider - clock layerClock + db *sql.Database + atxsdata *atxsdata.Data + activeSets *lru.Cache[types.Hash32, uint64] + edVerifier *signing.EdVerifier + publisher pubsub.Publisher + fetcher system.Fetcher + mesh meshProvider + validator eligibilityValidator + tortoise tortoiseProvider + weightCalcLock sync.Mutex + pendingWeightCalc map[types.Hash32][]chan uint64 + clock layerClock proposals proposalsConsumer } @@ -123,18 +126,19 @@ func NewHandler( panic(err) } b := &Handler{ - logger: log.NewNop(), - cfg: defaultConfig(), - db: db, - atxsdata: atxsdata, - proposals: proposals, - activeSets: activeSets, - edVerifier: edVerifier, - publisher: p, - fetcher: f, - mesh: m, - tortoise: tortoise, - clock: clock, + logger: log.NewNop(), + cfg: defaultConfig(), + db: db, + atxsdata: atxsdata, + 
proposals:         proposals,
+		activeSets:        activeSets,
+		edVerifier:        edVerifier,
+		publisher:         p,
+		fetcher:           f,
+		mesh:              m,
+		tortoise:          tortoise,
+		pendingWeightCalc: make(map[types.Hash32][]chan uint64),
+		clock:             clock,
 	}
 	for _, opt := range opts {
 		opt(b)
@@ -519,6 +523,112 @@ func (h *Handler) checkBallotSyntacticValidity(
 	return decoded, nil
 }
 
+// getActiveSetWeight returns the total weight of the active set with the
+// given id, making sure that concurrent requests for the same id share a
+// single fetch and weight calculation.
+//
+// A cached weight is returned right away. Otherwise the first caller (the
+// initiator) registers the id in pendingWeightCalc, fetches the active set
+// and computes its weight, while callers arriving in the meantime subscribe
+// to the initiator's result. Only successful results are cached, so a call
+// made after a failed attempt starts over.
+func (h *Handler) getActiveSetWeight(ctx context.Context, id types.Hash32) (uint64, error) {
+	h.weightCalcLock.Lock()
+	if weight, cached := h.activeSets.Get(id); cached {
+		h.weightCalcLock.Unlock()
+		return weight, nil
+	}
+	if subscribers, running := h.pendingWeightCalc[id]; running {
+		// The calculation is running or the activeset is being fetched,
+		// subscribe. The channel is buffered so that the initiator never
+		// blocks on it, even when this caller stops waiting because its
+		// context is canceled.
+		ch := make(chan uint64, 1)
+		h.pendingWeightCalc[id] = append(subscribers, ch)
+		h.weightCalcLock.Unlock()
+		return waitActiveSetWeight(ctx, ch)
+	}
+	// Mark the calculation as running: the key is present with no
+	// subscribers yet.
+	h.pendingWeightCalc[id] = nil
+	h.weightCalcLock.Unlock()
+
+	var (
+		totalWeight uint64
+		success     bool
+	)
+	defer func() {
+		h.weightCalcLock.Lock()
+		defer h.weightCalcLock.Unlock()
+		// This is guaranteed not to block b/c each channel is buffered.
+		// On failure the channels are closed without a value.
+		for _, ch := range h.pendingWeightCalc[id] {
+			if success {
+				ch <- totalWeight
+			}
+			close(ch)
+		}
+		delete(h.pendingWeightCalc, id)
+	}()
+
+	computed, err := h.computeActiveSetWeight(ctx, id)
+	if err != nil {
+		return 0, err
+	}
+	totalWeight = computed
+	h.activeSets.Add(id, totalWeight)
+	success = true // totalWeight will be sent to the subscribers
+	return totalWeight, nil
+}
+
+// waitActiveSetWeight waits for the initiator of an active set weight
+// calculation to deliver the result on ch. It returns the context's error if
+// ctx is done first, and a generic error if ch is closed without a value.
+func waitActiveSetWeight(ctx context.Context, ch <-chan uint64) (uint64, error) {
+	select {
+	case <-ctx.Done():
+		return 0, ctx.Err()
+	case weight, ok := <-ch:
+		if !ok {
+			// Channel closed, fetch / calculation failed.
+			// The actual error will be logged by the initiator of the
+			// initial fetch / calculation, let's not make an
+			// impression it happened multiple times and use a simpler
+			// message
+			return 0, errors.New("error getting activeset weight")
+		}
+		return weight, nil
+	}
+}
+
+// computeActiveSetWeight fetches the active set with the given id, loads it
+// from the database and returns the weight atxsdata reports for it. It fails
+// with pubsub.ErrValidationReject for an empty set, and with a plain error
+// if atxsdata does not mark every ATX of the set as used.
+func (h *Handler) computeActiveSetWeight(ctx context.Context, id types.Hash32) (uint64, error) {
+	if err := h.fetcher.GetActiveSet(ctx, id); err != nil {
+		return 0, err
+	}
+	set, err := activesets.Get(h.db, id)
+	if err != nil {
+		return 0, err
+	}
+	if len(set.Set) == 0 {
+		return 0, fmt.Errorf("%w: empty active set", pubsub.ErrValidationReject)
+	}
+
+	weight, used := h.atxsdata.WeightForSet(set.Epoch, set.Set)
+	for i := range used {
+		if !used[i] {
+			return 0, fmt.Errorf(
+				"missing atx %s in active set",
+				set.Set[i].ShortString(),
+			)
+		}
+	}
+	return weight, nil
+}
+
 func (h *Handler) checkBallotDataIntegrity(ctx context.Context, b *types.Ballot) (uint64, error) {
 	//nolint:nestif
 	if b.RefBallot == types.EmptyBallotID {
@@ -534,36 +644,9 @@ func (h *Handler) checkBallotDataIntegrity(ctx context.Context, b *types.Ballot)
 			epoch-- // download activesets in the previous epoch too
 		}
 		if b.Layer.GetEpoch() >= epoch {
-			var exists bool
-			totalWeight, exists := h.activeSets.Get(b.EpochData.ActiveSetHash)
-			if !exists {
-				if err := h.fetcher.GetActiveSet(ctx, b.EpochData.ActiveSetHash); err != nil {
-					return 0, err
-				}
-				set, err := activesets.Get(h.db, b.EpochData.ActiveSetHash)
-				if err != nil {
-					return 0, err
-				}
-				if len(set.Set) == 0 {
-					return 0,
fmt.Errorf(
-						"%w: empty active set ballot %s",
-						pubsub.ErrValidationReject,
-						b.ID().String(),
-					)
-				}
-
-				computed, used := h.atxsdata.WeightForSet(set.Epoch, set.Set)
-				for i := range used {
-					if !used[i] {
-						return 0, fmt.Errorf(
-							"missing atx %s in active set ballot %s",
-							set.Set[i].ShortString(),
-							b.ID().String(),
-						)
-					}
-				}
-				totalWeight = computed
-				h.activeSets.Add(b.EpochData.ActiveSetHash, totalWeight)
+			totalWeight, err := h.getActiveSetWeight(ctx, b.EpochData.ActiveSetHash)
+			if err != nil {
+				return 0, fmt.Errorf("ballot %s: %w", b.ID().String(), err)
 			}
 			return totalWeight, nil
 		}
diff --git a/proposals/handler_test.go b/proposals/handler_test.go
index 7875d50b41..8b704d6b5f 100644
--- a/proposals/handler_test.go
+++ b/proposals/handler_test.go
@@ -13,6 +13,7 @@ import (
 	"github.com/prometheus/client_golang/prometheus/testutil"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/mock/gomock"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/spacemeshos/go-spacemesh/atxsdata"
 	"github.com/spacemeshos/go-spacemesh/codec"
@@ -1392,9 +1393,9 @@ func TestHandleActiveSet(t *testing.T) {
 	}
 }
 
-func gproposal(signer *signing.EdSigner, atxid types.ATXID,
+func gproposal(t *testing.T, signer *signing.EdSigner, atxid types.ATXID,
 	layer types.LayerID, edata *types.EpochData,
-) types.Proposal {
+) *types.Proposal {
 	p := types.Proposal{}
 	p.Layer = layer
 	p.AtxID = atxid
@@ -1406,39 +1407,55 @@ func gproposal(signer *signing.EdSigner, atxid types.ATXID,
 	if edata != nil {
 		p.SetBeacon(edata.Beacon)
 	}
-	return p
+	require.NoError(t, p.Initialize())
+	return &p
 }
 
-func TestHandleSyncedProposalActiveSet(t *testing.T) {
+// asTestHandler bundles a testHandler with the fixtures used by the active
+// set tests: two proposals referencing the same active set, plus channels
+// that let a test pause and resume the mocked GetActiveSet call.
+type asTestHandler struct {
+	*testing.T
+	*testHandler
+	lid     types.LayerID
+	set     types.ATXIDList
+	p       []*types.Proposal
+	pid     p2p.Peer
+	startCh chan struct{} // signaled when a blocking GetActiveSet mock starts
+	contCh  chan error    // result that releases a blocked GetActiveSet mock
+}
+
+// createASTestHandler creates the handler with two proposals sharing one
+// active set and registers the mock expectations common to the subtests.
+func createASTestHandler(t *testing.T) *asTestHandler {
 	signer, err := signing.NewEdSigner()
 	require.NoError(t, err)
 
-	set := types.ATXIDList{{1}, {2}}
-	lid := types.LayerID(20)
-	good := gproposal(signer, types.ATXID{1}, lid, &types.EpochData{
-		ActiveSetHash: set.Hash(),
-		Beacon:        types.Beacon{1},
-	})
-	require.NoError(t, good.Initialize())
-	th := createTestHandler(t)
-	pid := p2p.Peer("any")
-
-	th.mclock.EXPECT().CurrentLayer().Return(lid).AnyTimes()
-	th.mm.EXPECT().ProcessedLayer().Return(lid - 2).AnyTimes()
-	th.mclock.EXPECT().LayerToTime(gomock.Any())
-	th.mf.EXPECT().RegisterPeerHashes(pid, gomock.Any()).AnyTimes()
-	th.mf.EXPECT().GetActiveSet(gomock.Any(), set.Hash()).DoAndReturn(
-		func(_ context.Context, got types.Hash32) error {
-			require.NoError(t, activesets.Add(th.db, got, &types.EpochActiveSet{
-				Epoch: lid.GetEpoch(),
-				Set:   set,
-			}))
-			for _, id := range set {
-				th.atxsdata.AddAtx(lid.GetEpoch(), id, &atxsdata.ATX{Node: types.NodeID{1}})
-			}
-			return nil
-		},
-	)
+	th := &asTestHandler{
+		T:           t,
+		testHandler: createTestHandler(t),
+		lid:         types.LayerID(20),
+		set:         types.ATXIDList{{1}, {2}, {3}},
+		pid:         p2p.Peer("any"),
+		startCh:     make(chan struct{}),
+		contCh:      make(chan error),
+	}
+	th.p = []*types.Proposal{
+		gproposal(t, signer, types.ATXID{1}, th.lid, &types.EpochData{
+			ActiveSetHash: th.set.Hash(),
+			Beacon:        types.Beacon{1},
+		}),
+		gproposal(t, signer, types.ATXID{2}, th.lid, &types.EpochData{
+			ActiveSetHash: th.set.Hash(),
+			Beacon:        types.Beacon{1},
+		}),
+	}
+
+	th.mclock.EXPECT().CurrentLayer().Return(th.lid).AnyTimes()
+	th.mm.EXPECT().ProcessedLayer().Return(th.lid - 2).AnyTimes()
+	th.mclock.EXPECT().LayerToTime(gomock.Any()).AnyTimes()
+	th.mf.EXPECT().RegisterPeerHashes(th.pid, gomock.Any()).AnyTimes()
+
 	th.mf.EXPECT().GetAtxs(gomock.Any(), gomock.Any()).AnyTimes()
 	th.mf.EXPECT().GetBallots(gomock.Any(), gomock.Any()).AnyTimes()
 	th.mockSet.decodeAnyBallots()
@@ -1446,10 +1463,191 @@ func TestHandleSyncedProposalActiveSet(t *testing.T) {
 	th.mm.EXPECT().AddBallot(gomock.Any(), gomock.Any()).AnyTimes()
 	th.mm.EXPECT().AddTXsFromProposal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
 
-	th.mconsumer.EXPECT().IsKnown(good.Layer, good.ID())
-	th.mconsumer.EXPECT().OnProposal(gomock.Eq(&good))
-	err = th.HandleSyncedProposal(context.Background(), good.ID().AsHash32(), pid, codec.MustEncode(&good))
-	require.NoError(t, err)
+	return th
+}
+
+// expectIsKnown expects the handler to ask the consumer whether proposal n
+// is already known.
+func (th *asTestHandler) expectIsKnown(n int) {
+	th.mconsumer.EXPECT().IsKnown(th.p[n].Layer, th.p[n].ID())
+}
+
+// expectProposal expects proposal n to be looked up and then passed to the
+// consumer.
+func (th *asTestHandler) expectProposal(n int) {
+	th.expectIsKnown(n)
+	th.mconsumer.EXPECT().OnProposal(gomock.Eq(th.p[n]))
+}
+
+// blockOnGetActiveSet signals on startCh that the fetch has started and then
+// waits for the result supplied through continueFetching. It returns early
+// with the context's error if ctx is done.
+func (th *asTestHandler) blockOnGetActiveSet(ctx context.Context) error {
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case th.startCh <- struct{}{}:
+	}
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case err := <-th.contCh:
+		return err
+	}
+}
+
+// waitForFetchToStart blocks until a blocking GetActiveSet mock is entered.
+func (th *asTestHandler) waitForFetchToStart() {
+	<-th.startCh
+}
+
+// continueFetching releases a blocked GetActiveSet mock, making it fail with
+// err if err is not nil.
+func (th *asTestHandler) continueFetching(err error) {
+	th.contCh <- err
+}
+
+// expectGetActiveSet expects a GetActiveSet call for the test active set.
+// With block set, the call waits for continueFetching before proceeding.
+// On success the mock stores the active set and its ATXs.
+func (th *asTestHandler) expectGetActiveSet(block bool) {
+	th.mf.EXPECT().GetActiveSet(gomock.Any(), th.set.Hash()).DoAndReturn(
+		func(ctx context.Context, got types.Hash32) error {
+			if block {
+				if err := th.blockOnGetActiveSet(ctx); err != nil {
+					return err
+				}
+			}
+			// This callback may run on a goroutine started by a subtest, so
+			// report failures through the return value instead of require,
+			// which must only be used on the test goroutine.
+			err := activesets.Add(th.db, got, &types.EpochActiveSet{
+				Epoch: th.lid.GetEpoch(),
+				Set:   th.set,
+			})
+			if err != nil {
+				return err
+			}
+			for _, id := range th.set {
+				th.atxsdata.AddAtx(th.lid.GetEpoch(), id, &atxsdata.ATX{Node: types.NodeID{1}})
+			}
+			return nil
+		},
+	)
+}
+
+// handleSyncedProposal feeds proposal n to HandleSyncedProposal.
+func (th *asTestHandler) handleSyncedProposal(ctx context.Context, n int) error {
+	return th.HandleSyncedProposal(
+		ctx, th.p[n].ID().AsHash32(), th.pid, codec.MustEncode(th.p[n]))
+}
+
+// waitForSubscription waits until at least one caller has subscribed to the
+// pending weight calculation for the test active set.
+func (th *asTestHandler) waitForSubscription() {
+	require.Eventually(th, func() bool {
+		th.weightCalcLock.Lock()
+		defer th.weightCalcLock.Unlock()
+		return len(th.pendingWeightCalc[th.set.Hash()]) != 0
+	}, 10*time.Second, 10*time.Millisecond)
+}
+
+// TestHandleSyncedProposalActiveSet checks that the active set is fetched
+// once for proposals handled sequentially or concurrently, and that a failed
+// or canceled fetch is reported to every waiting caller and can be retried.
+func TestHandleSyncedProposalActiveSet(t *testing.T) {
+	ctx := context.Background()
+
+	t.Run("non-concurrent fetch", func(t *testing.T) {
+		th := createASTestHandler(t)
+		th.expectProposal(0)
+		th.expectGetActiveSet(false)
+		require.NoError(t, th.handleSyncedProposal(ctx, 0))
+
+		th.expectProposal(1)
+		// ActiveSet not fetched again here
+		require.NoError(t, th.handleSyncedProposal(ctx, 1))
+	})
+
+	t.Run("concurrent fetch", func(t *testing.T) {
+		th := createASTestHandler(t)
+		th.expectProposal(0)
+		th.expectGetActiveSet(true)
+		var eg errgroup.Group
+		eg.Go(func() error { return th.handleSyncedProposal(ctx, 0) })
+		th.waitForFetchToStart()
+		th.expectProposal(1)
+		eg.Go(func() error { return th.handleSyncedProposal(ctx, 1) })
+		th.waitForSubscription()
+		th.continueFetching(nil)
+		require.NoError(t, eg.Wait())
+	})
+
+	t.Run("fetch failure and refetch", func(t *testing.T) {
+		th := createASTestHandler(t)
+		th.expectIsKnown(0)
+		th.expectGetActiveSet(true)
+		// The results are collected here and asserted after eg.Wait():
+		// require must not be called from the goroutines.
+		var (
+			eg         errgroup.Group
+			err0, err1 error
+		)
+		eg.Go(func() error {
+			err0 = th.handleSyncedProposal(ctx, 0)
+			return nil
+		})
+		th.waitForFetchToStart()
+		th.expectIsKnown(1)
+		eg.Go(func() error {
+			err1 = th.handleSyncedProposal(ctx, 1)
+			return nil
+		})
+		th.waitForSubscription()
+		th.continueFetching(errors.New("fail"))
+		require.NoError(t, eg.Wait())
+		require.Error(t, err0)
+		require.Error(t, err1)
+
+		// refetch
+		th.expectProposal(0)
+		th.expectGetActiveSet(false)
+		require.NoError(t, th.handleSyncedProposal(ctx, 0))
+	})
+
+	t.Run("cancel fetch and refetch", func(t *testing.T) {
+		// Keep the parent ctx usable for the refetch below.
+		cancelCtx, cancel := context.WithCancel(ctx)
+		defer cancel()
+		th := createASTestHandler(t)
+		th.expectIsKnown(0)
+		th.expectGetActiveSet(true)
+		var (
+			eg         errgroup.Group
+			err0, err1 error
+		)
+		eg.Go(func() error {
+			err0 = th.handleSyncedProposal(cancelCtx, 0)
+			return nil
+		})
+		th.waitForFetchToStart()
+		th.expectIsKnown(1)
+		eg.Go(func() error {
+			err1 = th.handleSyncedProposal(cancelCtx, 1)
+			return nil
+		})
+		th.waitForSubscription()
+		cancel()
+		require.NoError(t, eg.Wait())
+		require.ErrorIs(t, err0, context.Canceled)
+		require.ErrorIs(t, err1, context.Canceled)
+
+		// refetch
+		th.expectProposal(0)
+		th.expectGetActiveSet(false)
+		require.NoError(t, th.handleSyncedProposal(ctx, 0))
+	})
 }
 
 func TestHandler_SettingBallotBeacon(t *testing.T) {