From 0e1bd2c509d91f80923b5415fc70921804744ab7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Tue, 9 Jul 2024 14:33:41 +0000 Subject: [PATCH] use singleflight to deduplicate ATX processing (#6106) ## Motivation The `singleflight` pkg allows us to simplify the code. --- activation/handler.go | 70 +++++++++++-------------------------------- 1 file changed, 18 insertions(+), 52 deletions(-) diff --git a/activation/handler.go b/activation/handler.go index c828b45391..8290d7883c 100644 --- a/activation/handler.go +++ b/activation/handler.go @@ -5,11 +5,11 @@ import ( "errors" "fmt" "slices" - "sync" "time" "go.uber.org/zap" "go.uber.org/zap/zapcore" + "golang.org/x/sync/singleflight" "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/atxsdata" @@ -72,10 +72,8 @@ type Handler struct { logger *zap.Logger versions []atxVersion - // inProgress map gathers ATXs that are currently being processed. - // It's used to avoid processing the same ATX twice. - inProgress map[types.ATXID][]chan error - inProgressMu sync.Mutex + // inProgress is used to avoid processing the same ATX multiple times in parallel. + inProgress singleflight.Group v1 *HandlerV1 v2 *HandlerV2 @@ -149,8 +147,6 @@ func NewHandler( beacon: beacon, tortoise: tortoise, }, - - inProgress: make(map[types.ATXID][]chan error), } for _, opt := range opts { @@ -277,56 +273,26 @@ func (h *Handler) handleAtx( return nil, fmt.Errorf("%w: atx want %s, got %s", errWrongHash, expHash.ShortString(), id.ShortString()) } - // Check if processing is already in progress - h.inProgressMu.Lock() - if sub, ok := h.inProgress[id]; ok { - ch := make(chan error, 1) - h.inProgress[id] = append(sub, ch) - h.inProgressMu.Unlock() - h.logger.Debug("atx is already being processed. waiting for result", + key := string(id.Bytes()) + proof, err, _ := h.inProgress.Do(key, func() (any, error) { + h.logger.Debug("handling incoming atx", log.ZContext(ctx), zap.Stringer("atx_id", id), + zap.Int("size", len(msg)), ) - select { - case err := <-ch: - h.logger.Debug("atx processed in other task", - log.ZContext(ctx), - zap.Stringer("atx_id", id), - zap.Error(err), - ) - return nil, err - case <-ctx.Done(): - return nil, ctx.Err() - } - } - h.inProgress[id] = []chan error{} - h.inProgressMu.Unlock() - h.logger.Info("handling incoming atx", - log.ZContext(ctx), - zap.Stringer("atx_id", id), - zap.Int("size", len(msg)), - ) - - var proof *mwire.MalfeasanceProof - - switch atx := opaqueAtx.(type) { - case *wire.ActivationTxV1: - proof, err = h.v1.processATX(ctx, peer, atx, receivedTime) - case *wire.ActivationTxV2: - proof, err = h.v2.processATX(ctx, peer, atx, receivedTime) - default: - panic("unreachable") - } + switch atx := opaqueAtx.(type) { + case *wire.ActivationTxV1: + return h.v1.processATX(ctx, peer, atx, receivedTime) + case *wire.ActivationTxV2: + return h.v2.processATX(ctx, peer, atx, receivedTime) + default: + panic("unreachable") + } + }) + h.inProgress.Forget(key) - h.inProgressMu.Lock() - defer h.inProgressMu.Unlock() - for _, ch := range h.inProgress[id] { - ch <- err - close(ch) - } - delete(h.inProgress, id) - return proof, err + return proof.(*mwire.MalfeasanceProof), err } // Obtain the atxSignature of the given ATX.