Skip to content

Commit

Permalink
verify atx syntactic correctness before fetching deps
Browse files Browse the repository at this point in the history
  • Loading branch information
countvonzero committed Sep 5, 2023
1 parent 61e3c63 commit 210f2ec
Show file tree
Hide file tree
Showing 11 changed files with 333 additions and 384 deletions.
2 changes: 1 addition & 1 deletion activation/activation.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ func (b *Builder) verifyInitialPost(ctx context.Context, post *types.Post, metad
if err != nil {
b.log.With().Panic("failed to fetch commitment ATX ID.", log.Err(err))
}
err = b.validator.Post(ctx, types.EpochID(0), b.nodeID, commitmentAtxId, post, metadata, b.postSetupProvider.LastOpts().NumUnits)
err = b.validator.Post(ctx, b.nodeID, commitmentAtxId, post, metadata, b.postSetupProvider.LastOpts().NumUnits)
switch {
case errors.Is(err, context.Canceled):
// If the context was canceled, we don't want to emit or log errors just propagate the cancellation signal.
Expand Down
184 changes: 90 additions & 94 deletions activation/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ type Handler struct {
edVerifier *signing.EdVerifier
clock layerClock
publisher pubsub.Publisher
layersPerEpoch uint32
tickSize uint64
goldenATXID types.ATXID
nipostValidator nipostValidator
Expand All @@ -63,7 +62,6 @@ func NewHandler(
c layerClock,
pub pubsub.Publisher,
fetcher system.Fetcher,
layersPerEpoch uint32,
tickSize uint64,
goldenATXID types.ATXID,
nipostValidator nipostValidator,
Expand All @@ -77,7 +75,6 @@ func NewHandler(
edVerifier: edVerifier,
clock: c,
publisher: pub,
layersPerEpoch: layersPerEpoch,
tickSize: tickSize,
goldenATXID: goldenATXID,
nipostValidator: nipostValidator,
Expand Down Expand Up @@ -165,47 +162,84 @@ func (h *Handler) ProcessAtx(ctx context.Context, atx *types.VerifiedActivationT
return nil
}

// SyntacticallyValidateAtx ensures the following conditions apply, otherwise it returns an error.
//
// - The PublishEpoch is less than or equal to the current epoch + 1.
// - If the sequence number is non-zero: PrevATX points to a syntactically valid ATX whose sequence number is one less
// than the current ATXs sequence number.
// - If the sequence number is zero: PrevATX is empty.
// - Positioning ATX points to a syntactically valid ATX.
// - NIPost challenge is a hash of the serialization of the following fields:
// NodeID, SequenceNumber, PrevATXID, LayerID, StartTick, PositioningATX.
// - The NIPost is valid.
// - ATX LayerID is NIPostLayerTime or less after the PositioningATX LayerID.
// - The ATX view of the previous epoch contains ActiveSetSize activations.
func (h *Handler) SyntacticallyValidateAtx(ctx context.Context, atx *types.ActivationTx) (*types.VerifiedActivationTx, error) {
func (h *Handler) SyntacticallyValidate(ctx context.Context, atx *types.ActivationTx) error {
if atx.NIPost == nil {
return fmt.Errorf("nil nipst for atx %s", atx.ShortString())
}
current := h.clock.CurrentLayer().GetEpoch()
if atx.PublishEpoch > current+1 {
return fmt.Errorf("atx publish epoch is too far in the future: %d > %d", atx.PublishEpoch, current+1)
}
if atx.PositioningATX == types.EmptyATXID {
return fmt.Errorf("empty positioning atx")
}

switch {
case atx.PrevATXID == types.EmptyATXID:
if atx.InitialPost == nil {
return fmt.Errorf("no prev atx declared, but initial post is not included")
}
if atx.InnerActivationTx.NodeID == nil {
return fmt.Errorf("no prev atx declared, but node id is missing")
}
if atx.VRFNonce == nil {
return fmt.Errorf("no prev atx declared, but vrf nonce is missing")
}
if atx.CommitmentATX == nil {
return fmt.Errorf("no prev atx declared, but commitment atx is missing")
}
if *atx.CommitmentATX == types.EmptyATXID {
return fmt.Errorf("empty commitment atx")
}
if atx.Sequence != 0 {
return fmt.Errorf("no prev atx declared, but sequence number not zero")
}

// Use the NIPost's Post metadata, while overriding the challenge to a zero challenge,
// as expected from the initial Post.
initialPostMetadata := *atx.NIPost.PostMetadata
initialPostMetadata.Challenge = shared.ZeroChallenge
if err := h.nipostValidator.Post(ctx, atx.SmesherID, *atx.CommitmentATX, atx.InitialPost, &initialPostMetadata, atx.NumUnits); err != nil {
return fmt.Errorf("invalid initial post: %w", err)
}
if err := h.nipostValidator.VRFNonce(atx.SmesherID, *atx.CommitmentATX, atx.VRFNonce, &initialPostMetadata, atx.NumUnits); err != nil {
return fmt.Errorf("invalid vrf nonce: %w", err)
}
default:
if atx.InnerActivationTx.NodeID != nil {
return fmt.Errorf("prev atx declared, but node id is included")
}
if atx.InitialPost != nil {
return fmt.Errorf("prev atx declared, but initial post is included")
}
if atx.CommitmentATX != nil {
return fmt.Errorf("rpev atx declared, but commitment atx is included")
}
}
return nil
}

func (h *Handler) SyntacticallyValidateDeps(ctx context.Context, atx *types.ActivationTx) (*types.VerifiedActivationTx, error) {
var (
commitmentATX *types.ATXID
err error
)

current := h.clock.CurrentLayer()
if atx.PublishEpoch > current.GetEpoch()+1 {
return nil, fmt.Errorf("atx publish epoch is too far in the future: %d > %d", atx.PublishEpoch, current.GetEpoch()+1)
}

if atx.PrevATXID == types.EmptyATXID {
if err := h.validateInitialAtx(ctx, atx); err != nil {
return nil, err
}
commitmentATX = atx.CommitmentATX // validateInitialAtx checks that commitmentATX is not nil and references an existing valid ATX
commitmentATX = atx.CommitmentATX
} else {
commitmentATX, err = h.getCommitmentAtx(atx)
if err != nil {
return nil, fmt.Errorf("commitment atx for %s not found: %w", atx.SmesherID, err)
}

err = h.validateNonInitialAtx(ctx, atx, *commitmentATX)
if err != nil {
if err := h.validateNonInitialAtx(ctx, atx, *commitmentATX); err != nil {
return nil, err
}
}

if err := h.nipostValidator.PositioningAtx(&atx.PositioningATX, h.cdb, h.goldenATXID, atx.PublishEpoch, h.layersPerEpoch); err != nil {
if err := h.nipostValidator.PositioningAtx(&atx.PositioningATX, h.cdb, h.goldenATXID, atx.PublishEpoch); err != nil {
return nil, err
}

Expand All @@ -218,7 +252,7 @@ func (h *Handler) SyntacticallyValidateAtx(ctx context.Context, atx *types.Activ
expectedChallengeHash := atx.NIPostChallenge.Hash()
h.log.WithContext(ctx).With().Info("validating nipost", log.String("expected_challenge_hash", expectedChallengeHash.String()), atx.ID())

leaves, err := h.nipostValidator.NIPost(ctx, atx.PublishEpoch, atx.SmesherID, *commitmentATX, atx.NIPost, expectedChallengeHash, atx.NumUnits)
leaves, err := h.nipostValidator.NIPost(ctx, atx.SmesherID, *commitmentATX, atx.NIPost, expectedChallengeHash, atx.NumUnits)
if err != nil {
return nil, fmt.Errorf("invalid nipost: %w", err)
}
Expand All @@ -227,44 +261,14 @@ func (h *Handler) SyntacticallyValidateAtx(ctx context.Context, atx *types.Activ
}

func (h *Handler) validateInitialAtx(ctx context.Context, atx *types.ActivationTx) error {
if atx.InitialPost == nil {
return fmt.Errorf("no prevATX declared, but initial Post is not included")
}

if atx.InnerActivationTx.NodeID == nil {
return fmt.Errorf("no prevATX declared, but NodeID is missing")
}

if err := h.nipostValidator.InitialNIPostChallenge(&atx.NIPostChallenge, h.cdb, h.goldenATXID); err != nil {
return err
}

// Use the NIPost's Post metadata, while overriding the challenge to a zero challenge,
// as expected from the initial Post.
initialPostMetadata := *atx.NIPost.PostMetadata
initialPostMetadata.Challenge = shared.ZeroChallenge

if err := h.nipostValidator.Post(ctx, atx.PublishEpoch, atx.SmesherID, *atx.CommitmentATX, atx.InitialPost, &initialPostMetadata, atx.NumUnits); err != nil {
return fmt.Errorf("invalid initial Post: %w", err)
}

if atx.VRFNonce == nil {
return fmt.Errorf("no prevATX declared, but VRFNonce is missing")
}

if err := h.nipostValidator.VRFNonce(atx.SmesherID, *atx.CommitmentATX, atx.VRFNonce, &initialPostMetadata, atx.NumUnits); err != nil {
return fmt.Errorf("invalid VRFNonce: %w", err)
}

atx.SetEffectiveNumUnits(atx.NumUnits)
return nil
}

func (h *Handler) validateNonInitialAtx(ctx context.Context, atx *types.ActivationTx, commitmentATX types.ATXID) error {
if atx.InnerActivationTx.NodeID != nil {
return fmt.Errorf("prevATX declared, but NodeID is included")
}

if err := h.nipostValidator.NIPostChallenge(&atx.NIPostChallenge, h.cdb, atx.SmesherID); err != nil {
return err
}
Expand All @@ -276,7 +280,7 @@ func (h *Handler) validateNonInitialAtx(ctx context.Context, atx *types.Activati

nonce := atx.VRFNonce
if atx.NumUnits > prevAtx.NumUnits && nonce == nil {
h.log.WithContext(ctx).With().Info("PoST size increased without new VRF Nonce, re-validating current nonce",
h.log.WithContext(ctx).With().Info("post size increased without new vrf Nonce, re-validating current nonce",
atx.ID(),
log.Stringer("smesher", atx.SmesherID),
)
Expand All @@ -291,14 +295,10 @@ func (h *Handler) validateNonInitialAtx(ctx context.Context, atx *types.Activati
if nonce != nil {
err = h.nipostValidator.VRFNonce(atx.SmesherID, commitmentATX, nonce, atx.NIPost.PostMetadata, atx.NumUnits)
if err != nil {
return fmt.Errorf("invalid VRFNonce: %w", err)
return fmt.Errorf("invalid vrf nonce: %w", err)
}
}

if atx.InitialPost != nil {
return fmt.Errorf("prevATX declared, but initial Post is included")
}

if prevAtx.NumUnits < atx.NumUnits {
atx.SetEffectiveNumUnits(prevAtx.NumUnits)
} else {
Expand Down Expand Up @@ -330,7 +330,7 @@ func (h *Handler) ContextuallyValidateAtx(atx *types.VerifiedActivationTx) error

if err == nil && atx.PrevATXID == types.EmptyATXID {
// no previous atx declared, but already seen at least one atx from node
return fmt.Errorf("no prevATX reported, but other atx with same nodeID (%v) found: %v", atx.SmesherID, lastAtx.ShortString())
return fmt.Errorf("no prev atx reported, but other atx with same node id (%v) found: %v", atx.SmesherID, lastAtx.ShortString())
}

if err == nil && atx.PrevATXID != lastAtx {
Expand Down Expand Up @@ -391,7 +391,7 @@ func (h *Handler) storeAtx(ctx context.Context, atx *types.VerifiedActivationTx)
}
encoded, err := codec.Encode(proof)
if err != nil {
h.log.With().Panic("failed to encode MalfeasanceProof", log.Err(err))
h.log.With().Panic("failed to encode malfeasance proof", log.Err(err))
}
if err := identities.SetMalicious(dbtx, atx.SmesherID, encoded, time.Now()); err != nil {
return fmt.Errorf("add malfeasance proof: %w", err)
Expand Down Expand Up @@ -438,7 +438,7 @@ func (h *Handler) storeAtx(ctx context.Context, atx *types.VerifiedActivationTx)
}
encodedProof, err := codec.Encode(&gossip)
if err != nil {
h.log.With().Fatal("failed to encode MalfeasanceGossip", log.Err(err))
h.log.With().Fatal("failed to encode malfeasance gossip", log.Err(err))
}
if err = h.publisher.Publish(ctx, pubsub.MalfeasanceProof, encodedProof); err != nil {
h.log.With().Error("failed to broadcast malfeasance proof", log.Err(err))
Expand Down Expand Up @@ -517,63 +517,54 @@ func (h *Handler) handleAtx(ctx context.Context, expHash types.Hash32, peer p2p.
return fmt.Errorf("failed to verify atx signature: %w", errMalformedData)
}

logger := h.log.WithContext(ctx).WithFields(atx.ID())
existing, _ := h.cdb.GetAtxHeader(atx.ID())
if existing != nil {
logger.With().Debug("received known atx")
return fmt.Errorf("%w atx %s", errKnownAtx, atx.ID())
}

if atx.NIPost == nil {
return fmt.Errorf("nil nipst in gossip for atx %s", atx.ShortString())
if err := h.SyntacticallyValidate(ctx, &atx); err != nil {
return err
}

h.registerHashes(&atx, peer)
if err := h.fetcher.GetPoetProof(ctx, atx.GetPoetProofRef()); err != nil {
return fmt.Errorf("received atx (%v) with syntactically invalid or missing poet proof (%x): %w",
atx.ShortString(), atx.GetPoetProofRef().ShortString(), err,
)
}

if err := h.FetchAtxReferences(ctx, &atx); err != nil {
return fmt.Errorf("received atx (%v) with missing references of prev or pos id %v, %v: %w",
atx.ID().ShortString(), atx.PrevATXID.ShortString(), atx.PositioningATX.ShortString(), err,
)
if err := h.FetchReferences(ctx, &atx); err != nil {
return err
}

vAtx, err := h.SyntacticallyValidateAtx(ctx, &atx)
vAtx, err := h.SyntacticallyValidateDeps(ctx, &atx)
if err != nil {
return fmt.Errorf("received syntactically invalid atx %v: %w", atx.ShortString(), err)
return fmt.Errorf("atx %v syntatically invalid based on deps: %w", atx.ShortString(), err)
}

if expHash != (types.Hash32{}) && vAtx.ID().Hash32() != expHash {
return fmt.Errorf("%w: atx want %s, got %s", errWrongHash, expHash.ShortString(), vAtx.ID().Hash32().ShortString())
}

err = h.ProcessAtx(ctx, vAtx)
if err != nil {
if err := h.ProcessAtx(ctx, vAtx); err != nil {
return fmt.Errorf("cannot process atx %v: %w", atx.ShortString(), err)
}
events.ReportNewActivation(vAtx)
logger.With().Info("new atx", log.Inline(vAtx), log.Int("size", len(msg)))
h.log.WithContext(ctx).With().Info("new atx", log.Inline(vAtx), log.Int("size", len(msg)))
return nil
}

// FetchAtxReferences fetches referenced ATXs from peers if they are not found in db.
func (h *Handler) FetchAtxReferences(ctx context.Context, atx *types.ActivationTx) error {
logger := h.log.WithContext(ctx)
// FetchReferences fetches referenced ATXs from peers if they are not found in db.
func (h *Handler) FetchReferences(ctx context.Context, atx *types.ActivationTx) error {
if err := h.fetcher.GetPoetProof(ctx, atx.GetPoetProofRef()); err != nil {
return fmt.Errorf("atx (%s) missing poet proof (%s): %w",
atx.ShortString(), atx.GetPoetProofRef().ShortString(), err,
)
}

atxIDs := make(map[types.ATXID]struct{}, 3)
if atx.PositioningATX != types.EmptyATXID && atx.PositioningATX != h.goldenATXID {
logger.With().Debug("fetching pos atx", atx.PositioningATX, atx.ID())
atxIDs[atx.PositioningATX] = struct{}{}
}

if atx.PrevATXID != types.EmptyATXID {
logger.With().Debug("fetching prev atx", atx.PrevATXID, atx.ID())
atxIDs[atx.PrevATXID] = struct{}{}
}
if atx.CommitmentATX != nil && *atx.CommitmentATX != h.goldenATXID {
logger.With().Debug("fetching commitment atx", *atx.CommitmentATX, atx.ID())
atxIDs[*atx.CommitmentATX] = struct{}{}
}

Expand All @@ -582,8 +573,13 @@ func (h *Handler) FetchAtxReferences(ctx context.Context, atx *types.ActivationT
}

if err := h.fetcher.GetAtxs(ctx, maps.Keys(atxIDs)); err != nil {
return fmt.Errorf("fetch referenced atxs: %w", err)
dbg := fmt.Sprintf("prev %v pos %v commit %v", atx.PrevATXID, atx.PositioningATX, atx.CommitmentATX)
return fmt.Errorf("fetch referenced atxs (%s): %w", dbg, err)
}
logger.With().Debug("done fetching references for atx", atx.ID())

h.log.WithContext(ctx).With().Debug("done fetching references for atx",
atx.ID(),
log.Int("num fetched", len(atxIDs)),
)
return nil
}
Loading

0 comments on commit 210f2ec

Please sign in to comment.