Skip to content

Commit

Permalink
Optimize searching for positioning ATX (#5952)
Browse files Browse the repository at this point in the history
## Motivation

Searching for positioning ATX is slow because:
- the SQL query is slow
- it usually happens at the same time as many ATXs are being inserted into the DB (the poet CG)
  • Loading branch information
poszu committed May 21, 2024
1 parent 5e6551a commit d5bb1b4
Show file tree
Hide file tree
Showing 14 changed files with 345 additions and 134 deletions.
132 changes: 83 additions & 49 deletions activation/activation.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"github.com/spacemeshos/go-spacemesh/activation/metrics"
"github.com/spacemeshos/go-spacemesh/activation/wire"
"github.com/spacemeshos/go-spacemesh/atxsdata"
"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/events"
Expand Down Expand Up @@ -72,6 +73,7 @@ type Builder struct {
coinbaseAccount types.Address
conf Config
db sql.Executor
atxsdata *atxsdata.Data
localDB *localsql.Database
publisher pubsub.Publisher
nipostBuilder nipostBuilder
Expand Down Expand Up @@ -152,6 +154,7 @@ func WithPostStates(ps PostStates) BuilderOption {
func NewBuilder(
conf Config,
db sql.Executor,
atxsdata *atxsdata.Data,
localDB *localsql.Database,
publisher pubsub.Publisher,
nipostBuilder nipostBuilder,
Expand All @@ -165,6 +168,7 @@ func NewBuilder(
signers: make(map[types.NodeID]*signing.EdSigner),
conf: conf,
db: db,
atxsdata: atxsdata,
localDB: localDB,
publisher: publisher,
nipostBuilder: nipostBuilder,
Expand Down Expand Up @@ -507,12 +511,6 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID)
}
}

posAtx, err := b.getPositioningAtx(ctx, nodeID, publish)
if err != nil {
return nil, fmt.Errorf("failed to get positioning ATX: %w", err)
}
logger.Info("selected positioning atx", log.ZShortStringer("atx_id", posAtx))

prevAtx, err = b.GetPrevAtx(nodeID)
switch {
case errors.Is(err, sql.ErrNotFound):
Expand All @@ -538,6 +536,10 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID)
}
return nil, fmt.Errorf("initial POST is invalid: %w", err)
}
posAtx, err := b.getPositioningAtx(ctx, nodeID, publish, nil)
if err != nil {
return nil, fmt.Errorf("failed to get positioning ATX: %w", err)
}
challenge = &types.NIPostChallenge{
PublishEpoch: publish,
Sequence: 0,
Expand All @@ -554,6 +556,10 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID)
return nil, fmt.Errorf("get last ATX: %w", err)
default:
// regular ATX challenge
posAtx, err := b.getPositioningAtx(ctx, nodeID, publish, prevAtx)
if err != nil {
return nil, fmt.Errorf("failed to get positioning ATX: %w", err)
}
challenge = &types.NIPostChallenge{
PublishEpoch: publish,
Sequence: prevAtx.Sequence + 1,
Expand Down Expand Up @@ -692,6 +698,12 @@ func (b *Builder) createAtx(
break
}
if nipostState.VRFNonce != oldNonce {
b.log.Info(
"attaching a new VRF nonce in ATX",
log.ZShortStringer("smesherID", sig.NodeID()),
zap.Uint64("new nonce", uint64(nipostState.VRFNonce)),
zap.Uint64("old nonce", uint64(oldNonce)),
)
nonce = &nipostState.VRFNonce
}
}
Expand Down Expand Up @@ -723,9 +735,9 @@ func (b *Builder) broadcast(ctx context.Context, atx scale.Encodable) (int, erro
return len(buf), nil
}

// getPositioningAtx returns atx id with the highest tick height.
// searchPositioningAtx returns atx id with the highest tick height.
// publish epoch is used for caching the positioning atx.
func (b *Builder) getPositioningAtx(
func (b *Builder) searchPositioningAtx(
ctx context.Context,
nodeID types.NodeID,
publish types.EpochID,
Expand All @@ -738,34 +750,65 @@ func (b *Builder) getPositioningAtx(
return found.id, nil
}

logger.Info("searching for positioning atx")
latestPublished, err := atxs.LatestEpoch(b.db)
if err != nil {
return types.EmptyATXID, fmt.Errorf("get latest epoch: %w", err)
}
logger.Info("searching for positioning atx", zap.Uint32("latest_epoch", latestPublished.Uint32()))
// positioning ATX publish epoch must be lower than the publish epoch of built ATX
positioningAtxPublished := min(latestPublished, publish-1)
id, err := findFullyValidHighTickAtx(
ctx,
b.db,
nodeID,
b.atxsdata,
positioningAtxPublished,
b.conf.GoldenATXID,
b.validator,
logger,
VerifyChainOpts.AssumeValidBefore(time.Now().Add(-b.postValidityDelay)),
VerifyChainOpts.WithTrustedID(nodeID),
VerifyChainOpts.WithLogger(b.log),
)
switch {
case err == nil:
b.posAtxFinder.found = &struct {
id types.ATXID
forPublish types.EpochID
}{id, publish}
return id, nil
case errors.Is(err, sql.ErrNotFound):
logger.Info("using golden atx as positioning atx")
b.posAtxFinder.found = &struct {
id types.ATXID
forPublish types.EpochID
}{b.conf.GoldenATXID, publish}
return b.conf.GoldenATXID, nil
if err != nil {
logger.Info("search failed - using golden atx as positioning atx", zap.Error(err))
id = b.conf.GoldenATXID
}
b.posAtxFinder.found = &struct {
id types.ATXID
forPublish types.EpochID
}{id, publish}

return id, nil
}

// getPositioningAtx returns the positioning ATX.
// The provided previous ATX is picked if it has a greater or equal
// tick count as the ATX selected in `searchPositioningAtx`.
func (b *Builder) getPositioningAtx(
ctx context.Context,
nodeID types.NodeID,
publish types.EpochID,
previous *types.ActivationTx,
) (types.ATXID, error) {
id, err := b.searchPositioningAtx(ctx, nodeID, publish)
if err != nil {
return types.EmptyATXID, err
}

if previous != nil {
switch {
case id == b.conf.GoldenATXID:
id = previous.ID()
case id != b.conf.GoldenATXID:
if candidate, err := atxs.Get(b.db, id); err == nil {
if previous.TickHeight() >= candidate.TickHeight() {
id = previous.ID()
}
}
}
}
return id, err

b.log.Info("selected positioning atx", log.ZShortStringer("id", id), log.ZShortStringer("smesherID", nodeID))
return id, nil
}

func (b *Builder) Regossip(ctx context.Context, nodeID types.NodeID) error {
Expand Down Expand Up @@ -797,35 +840,26 @@ func buildNipostChallengeStartDeadline(roundStart time.Time, gracePeriod time.Du

func findFullyValidHighTickAtx(
ctx context.Context,
db sql.Executor,
prefNodeID types.NodeID,
atxdata *atxsdata.Data,
publish types.EpochID,
goldenATXID types.ATXID,
validator nipostValidator,
logger *zap.Logger,
opts ...VerifyChainOption,
) (types.ATXID, error) {
rejectedAtxs := make(map[types.ATXID]struct{})
filter := func(id types.ATXID) bool {
_, ok := rejectedAtxs[id]
return !ok
}

for {
select {
case <-ctx.Done():
return types.ATXID{}, ctx.Err()
default:
}
id, err := atxs.GetIDWithMaxHeight(db, prefNodeID, filter)
if err != nil {
return types.ATXID{}, err
}
logger.Info("found candidate for high-tick atx, verifying its chain", log.ZShortStringer("atx_id", id))
var found *types.ATXID
atxdata.IterateHighTicksInEpoch(publish+1, func(id types.ATXID) bool {
logger.Info("found candidate for high-tick atx", log.ZShortStringer("id", id))
if err := validator.VerifyChain(ctx, id, goldenATXID, opts...); err != nil {
logger.Info("rejecting candidate for high-tick atx", zap.Error(err), log.ZShortStringer("atx_id", id))
rejectedAtxs[id] = struct{}{}
} else {
return id, nil
logger.Info("rejecting candidate for high-tick atx", zap.Error(err), log.ZShortStringer("id", id))
return true
}
found = &id
return false
})

if found != nil {
return *found, nil
}
return types.ATXID{}, ErrNotFound
}
Loading

0 comments on commit d5bb1b4

Please sign in to comment.