Skip to content

Commit

Permalink
use singleflight to deduplicate ATX processing (#6106)
Browse files Browse the repository at this point in the history
## Motivation

The `singleflight` pkg allows us to simplify the code.
  • Loading branch information
poszu committed Jul 9, 2024
1 parent e716a01 commit 0e1bd2c
Showing 1 changed file with 18 additions and 52 deletions.
70 changes: 18 additions & 52 deletions activation/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import (
"errors"
"fmt"
"slices"
"sync"
"time"

"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/sync/singleflight"

"github.com/spacemeshos/go-spacemesh/activation/wire"
"github.com/spacemeshos/go-spacemesh/atxsdata"
Expand Down Expand Up @@ -72,10 +72,8 @@ type Handler struct {
logger *zap.Logger
versions []atxVersion

// inProgress map gathers ATXs that are currently being processed.
// It's used to avoid processing the same ATX twice.
inProgress map[types.ATXID][]chan error
inProgressMu sync.Mutex
// inProgress is used to avoid processing the same ATX multiple times in parallel.
inProgress singleflight.Group

v1 *HandlerV1
v2 *HandlerV2
Expand Down Expand Up @@ -149,8 +147,6 @@ func NewHandler(
beacon: beacon,
tortoise: tortoise,
},

inProgress: make(map[types.ATXID][]chan error),
}

for _, opt := range opts {
Expand Down Expand Up @@ -277,56 +273,26 @@ func (h *Handler) handleAtx(
return nil, fmt.Errorf("%w: atx want %s, got %s", errWrongHash, expHash.ShortString(), id.ShortString())
}

// Check if processing is already in progress
h.inProgressMu.Lock()
if sub, ok := h.inProgress[id]; ok {
ch := make(chan error, 1)
h.inProgress[id] = append(sub, ch)
h.inProgressMu.Unlock()
h.logger.Debug("atx is already being processed. waiting for result",
key := string(id.Bytes())
proof, err, _ := h.inProgress.Do(key, func() (any, error) {
h.logger.Debug("handling incoming atx",
log.ZContext(ctx),
zap.Stringer("atx_id", id),
zap.Int("size", len(msg)),
)
select {
case err := <-ch:
h.logger.Debug("atx processed in other task",
log.ZContext(ctx),
zap.Stringer("atx_id", id),
zap.Error(err),
)
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
}
}

h.inProgress[id] = []chan error{}
h.inProgressMu.Unlock()
h.logger.Info("handling incoming atx",
log.ZContext(ctx),
zap.Stringer("atx_id", id),
zap.Int("size", len(msg)),
)

var proof *mwire.MalfeasanceProof

switch atx := opaqueAtx.(type) {
case *wire.ActivationTxV1:
proof, err = h.v1.processATX(ctx, peer, atx, receivedTime)
case *wire.ActivationTxV2:
proof, err = h.v2.processATX(ctx, peer, atx, receivedTime)
default:
panic("unreachable")
}
switch atx := opaqueAtx.(type) {
case *wire.ActivationTxV1:
return h.v1.processATX(ctx, peer, atx, receivedTime)
case *wire.ActivationTxV2:
return h.v2.processATX(ctx, peer, atx, receivedTime)
default:
panic("unreachable")
}
})
h.inProgress.Forget(key)

h.inProgressMu.Lock()
defer h.inProgressMu.Unlock()
for _, ch := range h.inProgress[id] {
ch <- err
close(ch)
}
delete(h.inProgress, id)
return proof, err
return proof.(*mwire.MalfeasanceProof), err
}

// Obtain the atxSignature of the given ATX.
Expand Down

0 comments on commit 0e1bd2c

Please sign in to comment.