Skip to content

Commit

Permalink
Fix/6041 bug (#6053)
Browse files Browse the repository at this point in the history
## Motivation

Fix the bug  



Co-authored-by: ConvallariaMaj <majalis.conv@gmail.com>
  • Loading branch information
0xBECEDA and ConvallariaMaj committed Jun 26, 2024
1 parent 66118d2 commit 2c83704
Show file tree
Hide file tree
Showing 13 changed files with 220 additions and 100 deletions.
110 changes: 76 additions & 34 deletions activation/activation.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ var (

// PoetConfig is the configuration to interact with the poet server.
type PoetConfig struct {
PhaseShift time.Duration `mapstructure:"phase-shift"`
CycleGap time.Duration `mapstructure:"cycle-gap"`
GracePeriod time.Duration `mapstructure:"grace-period"`
RequestTimeout time.Duration `mapstructure:"poet-request-timeout"`
RequestRetryDelay time.Duration `mapstructure:"retry-delay"`
MaxRequestRetries int `mapstructure:"retry-max"`
PhaseShift time.Duration `mapstructure:"phase-shift"`
CycleGap time.Duration `mapstructure:"cycle-gap"`
GracePeriod time.Duration `mapstructure:"grace-period"`
RequestTimeout time.Duration `mapstructure:"poet-request-timeout"`
RequestRetryDelay time.Duration `mapstructure:"retry-delay"`
PositioningATXSelectionTimeout time.Duration `mapstructure:"positioning-atx-selection-timeout"`
MaxRequestRetries int `mapstructure:"retry-max"`
}

func DefaultPoetConfig() PoetConfig {
Expand All @@ -56,12 +57,6 @@ func DefaultPoetConfig() PoetConfig {

const (
defaultPoetRetryInterval = 5 * time.Second

// Jitter added to the wait time before building a nipost challenge.
// It is expressed as % of poet grace period which translates to:
// mainnet (grace period 1h) -> 36s
// systest (grace period 10s) -> 0.1s
maxNipostChallengeBuildJitter = 1.0
)

// Config defines configuration for Builder.
Expand Down Expand Up @@ -203,6 +198,7 @@ func NewBuilder(
for _, opt := range opts {
opt(b)
}

return b
}

Expand Down Expand Up @@ -547,8 +543,12 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID)
until = time.Until(b.poetRoundStart(current))
}
publish := current + 1

poetStartsAt := b.poetRoundStart(current)

metrics.PublishOntimeWindowLatency.Observe(until.Seconds())
wait := buildNipostChallengeStartDeadline(b.poetRoundStart(current), b.poetCfg.GracePeriod)

wait := poetStartsAt.Add(-b.poetCfg.GracePeriod)
if time.Until(wait) > 0 {
logger.Info("paused building NiPoST challenge. Waiting until closer to poet start to get a better posATX",
zap.Duration("till poet round", until),
Expand All @@ -563,6 +563,14 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID)
}
}

if b.poetCfg.PositioningATXSelectionTimeout > 0 {
var cancel context.CancelFunc

deadline := poetStartsAt.Add(-b.poetCfg.GracePeriod).Add(b.poetCfg.PositioningATXSelectionTimeout)
ctx, cancel = context.WithDeadline(ctx, deadline)
defer cancel()
}

prevAtx, err = b.GetPrevAtx(nodeID)
switch {
case errors.Is(err, sql.ErrNotFound):
Expand All @@ -585,6 +593,7 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID)
}
return nil, fmt.Errorf("initial POST is invalid: %w", err)
}

posAtx, err := b.getPositioningAtx(ctx, nodeID, publish, nil)
if err != nil {
return nil, fmt.Errorf("failed to get positioning ATX: %w", err)
Expand All @@ -604,7 +613,6 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID)
case err != nil:
return nil, fmt.Errorf("get last ATX: %w", err)
default:
// regular ATX challenge
posAtx, err := b.getPositioningAtx(ctx, nodeID, publish, prevAtx)
if err != nil {
return nil, fmt.Errorf("failed to get positioning ATX: %w", err)
Expand Down Expand Up @@ -851,8 +859,10 @@ func (b *Builder) searchPositioningAtx(
publish types.EpochID,
) (types.ATXID, error) {
logger := b.logger.With(log.ZShortStringer("smesherID", nodeID), zap.Uint32("publish epoch", publish.Uint32()))

b.posAtxFinder.finding.Lock()
defer b.posAtxFinder.finding.Unlock()

if found := b.posAtxFinder.found; found != nil && found.forPublish == publish {
logger.Debug("using cached positioning atx", log.ZShortStringer("atx_id", found.id))
return found.id, nil
Expand All @@ -862,7 +872,9 @@ func (b *Builder) searchPositioningAtx(
if err != nil {
return types.EmptyATXID, fmt.Errorf("get latest epoch: %w", err)
}

logger.Info("searching for positioning atx", zap.Uint32("latest_epoch", latestPublished.Uint32()))

// positioning ATX publish epoch must be lower than the publish epoch of built ATX
positioningAtxPublished := min(latestPublished, publish-1)
id, err := findFullyValidHighTickAtx(
Expand All @@ -880,6 +892,7 @@ func (b *Builder) searchPositioningAtx(
logger.Info("search failed - using golden atx as positioning atx", zap.Error(err))
id = b.conf.GoldenATXID
}

b.posAtxFinder.found = &struct {
id types.ATXID
forPublish types.EpochID
Expand All @@ -902,17 +915,39 @@ func (b *Builder) getPositioningAtx(
return types.EmptyATXID, err
}

if previous != nil {
switch {
case id == b.conf.GoldenATXID:
id = previous.ID()
case id != b.conf.GoldenATXID:
if candidate, err := atxs.Get(b.db, id); err == nil {
if previous.TickHeight() >= candidate.TickHeight() {
id = previous.ID()
}
}
}
b.logger.Info("found candidate positioning atx",
log.ZShortStringer("id", id),
log.ZShortStringer("smesherID", nodeID),
)

if previous == nil {
b.logger.Info("selected atx as positioning atx",
log.ZShortStringer("id", id),
log.ZShortStringer("smesherID", nodeID))
return id, nil
}

if id == b.conf.GoldenATXID {
id = previous.ID()
b.logger.Info("selected previous as positioning atx",
log.ZShortStringer("id", id),
log.ZShortStringer("smesherID", nodeID),
)
return id, nil
}

candidate, err := atxs.Get(b.db, id)
if err != nil {
return types.EmptyATXID, fmt.Errorf("get candidate pos ATX %s: %w", id.ShortString(), err)
}

if previous.TickHeight() >= candidate.TickHeight() {
id = previous.ID()
b.logger.Info("selected previous as positioning atx",
log.ZShortStringer("id", id),
log.ZShortStringer("smesherID", nodeID),
)
return id, nil
}

b.logger.Info("selected positioning atx", log.ZShortStringer("id", id), log.ZShortStringer("smesherID", nodeID))
Expand Down Expand Up @@ -941,11 +976,6 @@ func (b *Builder) Regossip(ctx context.Context, nodeID types.NodeID) error {
return nil
}

func buildNipostChallengeStartDeadline(roundStart time.Time, gracePeriod time.Duration) time.Time {
jitter := randomDurationInRange(time.Duration(0), gracePeriod*maxNipostChallengeBuildJitter/100.0)
return roundStart.Add(jitter).Add(-gracePeriod)
}

func (b *Builder) version(publish types.EpochID) types.AtxVersion {
version := types.AtxV1
for _, v := range b.versions {
Expand All @@ -966,8 +996,15 @@ func findFullyValidHighTickAtx(
opts ...VerifyChainOption,
) (types.ATXID, error) {
var found *types.ATXID
atxdata.IterateHighTicksInEpoch(publish+1, func(id types.ATXID) bool {

// iterate through epochs, to get first valid, not malicious ATX with the biggest height
atxdata.IterateHighTicksInEpoch(publish+1, func(id types.ATXID) (contSearch bool) {
logger.Info("found candidate for high-tick atx", log.ZShortStringer("id", id))
if ctx.Err() != nil {
return false
}
// verify ATX-candidate by getting their dependencies (previous Atx, positioning ATX etc.)
// and verifying PoST for every dependency
if err := validator.VerifyChain(ctx, id, goldenATXID, opts...); err != nil {
logger.Info("rejecting candidate for high-tick atx", zap.Error(err), log.ZShortStringer("id", id))
return true
Expand All @@ -976,8 +1013,13 @@ func findFullyValidHighTickAtx(
return false
})

if found != nil {
return *found, nil
if ctx.Err() != nil {
return types.ATXID{}, ctx.Err()
}

if found == nil {
return types.ATXID{}, ErrNotFound
}
return types.ATXID{}, ErrNotFound

return *found, nil
}
73 changes: 40 additions & 33 deletions activation/activation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1292,10 +1292,13 @@ func TestWaitPositioningAtx(t *testing.T) {
tab.mnipost.EXPECT().
BuildNIPost(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(&nipost.NIPostState{}, nil)

closed := make(chan struct{})
close(closed)

tab.mclock.EXPECT().AwaitLayer(types.EpochID(1).FirstLayer()).Return(closed).AnyTimes()
tab.mclock.EXPECT().AwaitLayer(types.EpochID(2).FirstLayer()).Return(closed).AnyTimes()

tab.mpub.EXPECT().Publish(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, _ string, got []byte) error {
var atx wire.ActivationTxV1
Expand Down Expand Up @@ -1328,39 +1331,6 @@ func TestWaitPositioningAtx(t *testing.T) {
}
}

// TestWaitingToBuildNipostChallengeWithJitter checks the deadline returned by
// buildNipostChallengeStartDeadline for a 1h grace period and three positions of
// "now" relative to the start of the grace period: before it, shortly after it
// (still within the maximum jitter), and after the maximum jitter has passed.
func TestWaitingToBuildNipostChallengeWithJitter(t *testing.T) {
t.Run("before grace period", func(t *testing.T) {
// ┌──grace period──┐
// │ │
// ───▲─────|──────|─────────|----> time
// │ └jitter| └round start
// now
// Round starts in 2h, so the deadline must land between 1h and 1h+36s from now.
deadline := buildNipostChallengeStartDeadline(time.Now().Add(2*time.Hour), time.Hour)
require.Greater(t, deadline, time.Now().Add(time.Hour))
require.LessOrEqual(t, deadline, time.Now().Add(time.Hour+time.Second*36))
})
t.Run("after grace period, within max jitter value", func(t *testing.T) {
// ┌──grace period──┐
// │ │
// ─────────|──▲────|────────|----> time
// └ji│tter| └round start
// now
// Grace period began 10s ago, so the deadline may be up to 10s in the past
// or up to (36-10)s in the future.
deadline := buildNipostChallengeStartDeadline(time.Now().Add(time.Hour-time.Second*10), time.Hour)
require.GreaterOrEqual(t, deadline, time.Now().Add(-time.Second*10))
// jitter is 1% = 36s for 1h grace period
require.LessOrEqual(t, deadline, time.Now().Add(time.Second*(36-10)))
})
t.Run("after jitter max value", func(t *testing.T) {
// ┌──grace period──┐
// │ │
// ─────────|──────|──▲──────|----> time
// └jitter| │ └round start
// now
// Grace period began 37s ago, which exceeds the 36s maximum jitter, so the
// deadline must already be in the past.
deadline := buildNipostChallengeStartDeadline(time.Now().Add(time.Hour-time.Second*37), time.Hour)
require.Less(t, deadline, time.Now())
})
}

// Test if GetPositioningAtx disregards ATXs with invalid POST in their chain.
// It should pick an ATX with valid POST even though it's a lower height.
func TestGetPositioningAtxPicksAtxWithValidChain(t *testing.T) {
Expand Down Expand Up @@ -1471,6 +1441,43 @@ func TestGetPositioningAtx(t *testing.T) {
require.NoError(t, err)
require.Equal(t, prev.ID(), selected)
})
t.Run("prefers own previous or golded when positioning ATX selection timout expired", func(t *testing.T) {
tab := newTestBuilder(t, 1)

atxInDb := &types.ActivationTx{TickCount: 100}
atxInDb.SetID(types.RandomATXID())
require.NoError(t, atxs.Add(tab.db, atxInDb))
tab.atxsdata.AddFromAtx(atxInDb, false)

prev := &types.ActivationTx{TickCount: 90}
prev.SetID(types.RandomATXID())

// no timeout set up
tab.mValidator.EXPECT().VerifyChain(gomock.Any(), atxInDb.ID(), tab.goldenATXID, gomock.Any())
found, err := tab.getPositioningAtx(context.Background(), types.EmptyNodeID, 99, prev)
require.NoError(t, err)
require.Equal(t, atxInDb.ID(), found)

tab.posAtxFinder.found = nil

// timeout set up, prev ATX exists
ctx, cancel := context.WithCancel(context.Background())
cancel()

selected, err := tab.getPositioningAtx(ctx, types.EmptyNodeID, 99, prev)
require.NoError(t, err)
require.Equal(t, prev.ID(), selected)

tab.posAtxFinder.found = nil

// timeout set up, prev ATX does not exist
ctx, cancel = context.WithCancel(context.Background())
cancel()

selected, err = tab.getPositioningAtx(ctx, types.EmptyNodeID, 99, nil)
require.NoError(t, err)
require.Equal(t, tab.goldenATXID, selected)
})
}

func TestFindFullyValidHighTickAtx(t *testing.T) {
Expand Down
5 changes: 4 additions & 1 deletion activation/e2e/certifier_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,10 @@ func (c *testCertifier) certify(w http.ResponseWriter, r *http.Request) {
NumUnits: req.Metadata.NumUnits,
LabelsPerUnit: c.cfg.LabelsPerUnit,
}
if err := c.postVerifier.Verify(context.Background(), proof, metadata, c.opts...); err != nil {
if err := c.postVerifier.Verify(
context.Background(),
proof, metadata,
activation.WithVerifierOptions(c.opts...)); err != nil {
http.Error(w, fmt.Sprintf("verifying POST: %v", err), http.StatusBadRequest)
return
}
Expand Down
29 changes: 28 additions & 1 deletion activation/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,40 @@ type AtxReceiver interface {

type PostVerifier interface {
io.Closer
Verify(ctx context.Context, p *shared.Proof, m *shared.ProofMetadata, opts ...verifying.OptionFunc) error
Verify(ctx context.Context, p *shared.Proof, m *shared.ProofMetadata, opts ...postVerifierOptionFunc) error
}

// scaler is an unexported single-method interface for types that expose scale(int).
// NOTE(review): the meaning of the int argument is not visible in this file —
// confirm against the implementing type.
type scaler interface {
scale(int)
}

// postVerifierCallOption collects the per-call settings assembled from
// postVerifierOptionFunc values (see applyOptions).
type postVerifierCallOption struct {
// prioritized is set to true by the PrioritizedCall option.
prioritized bool
// verifierOptions holds the verifying.OptionFunc values supplied through
// WithVerifierOptions.
verifierOptions []verifying.OptionFunc
}

// postVerifierOptionFunc is a functional option that mutates a
// postVerifierCallOption; it is the option type accepted by PostVerifier.Verify.
type postVerifierOptionFunc func(*postVerifierCallOption)

// applyOptions starts from a zero-valued postVerifierCallOption, runs every
// supplied option against it in argument order, and returns the result.
func applyOptions(options ...postVerifierOptionFunc) postVerifierCallOption {
	var callOpts postVerifierCallOption
	for i := range options {
		options[i](&callOpts)
	}
	return callOpts
}

// PrioritizedCall returns an option that sets the prioritized flag on the
// call options it is applied to.
func PrioritizedCall() postVerifierOptionFunc {
	markPrioritized := func(callOpts *postVerifierCallOption) {
		callOpts.prioritized = true
	}
	return markPrioritized
}

// WithVerifierOptions returns an option that stores ops as the verifier options
// of the call. The slice is assigned, not appended: applying this option again
// replaces the verifier options set by an earlier application.
func WithVerifierOptions(ops ...verifying.OptionFunc) postVerifierOptionFunc {
	setVerifierOptions := func(callOpts *postVerifierCallOption) {
		callOpts.verifierOptions = ops
	}
	return setVerifierOptions
}

// validatorOption is a functional option type for the validator.
type validatorOption func(*validatorOptions)

Expand Down
2 changes: 1 addition & 1 deletion activation/malfeasance.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (mh *InvalidPostIndexHandler) Validate(ctx context.Context, data wire.Proof
ctx,
post,
meta,
verifying.SelectedIndex(int(proof.InvalidIdx)),
WithVerifierOptions(verifying.SelectedIndex(int(proof.InvalidIdx))),
); err != nil {
return atx.SmesherID, nil
}
Expand Down
Loading

0 comments on commit 2c83704

Please sign in to comment.