From e7983ab4d69eb3b4a0e630b10e94b89b0a75fd39 Mon Sep 17 00:00:00 2001 From: blxdyx Date: Wed, 4 Sep 2024 15:48:40 +0800 Subject: [PATCH] Refactor the reset snapshots --- consensus/consensus.go | 1 + consensus/parlia/parlia.go | 78 +++++++++++++++++++++---------- consensus/parlia/snapshot.go | 30 +++++++++++- eth/stagedsync/stage_snapshots.go | 11 ++++- 4 files changed, 91 insertions(+), 29 deletions(-) diff --git a/consensus/consensus.go b/consensus/consensus.go index 25674ff312a..3bc43a2cbcc 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -226,6 +226,7 @@ type PoSA interface { GetJustifiedNumberAndHash(chain ChainHeaderReader, header *types.Header) (uint64, libcommon.Hash, error) GetFinalizedHeader(chain ChainHeaderReader, header *types.Header) *types.Header ResetSnapshot(chain ChainHeaderReader, header *types.Header) error + GetBscProgress() (uint64, error) } type AsyncEngine interface { diff --git a/consensus/parlia/parlia.go b/consensus/parlia/parlia.go index ce6c31fa25f..16ba7e16cd0 100644 --- a/consensus/parlia/parlia.go +++ b/consensus/parlia/parlia.go @@ -1618,6 +1618,10 @@ func (p *Parlia) blockTimeVerifyForRamanujanFork(snap *Snapshot, header, parent return nil } +func (p *Parlia) GetBscProgress() (uint64, error) { + return getLatest(p.db) +} + // ResetSnapshot Fill consensus db from snapshot func (p *Parlia) ResetSnapshot(chain consensus.ChainHeaderReader, header *types.Header) error { // Search for a snapshot in memory or on disk for checkpoints @@ -1628,36 +1632,60 @@ func (p *Parlia) ResetSnapshot(chain consensus.ChainHeaderReader, header *types. hash := header.Hash() number := header.Number.Uint64() - // If we're at the genesis, snapshot the initial state. 
- if number == 0 { - // Headers included into the snapshots have to be trusted as checkpoints get validators from headers - validators, voteAddrs, err := parseValidators(header, p.chainConfig, p.config) - if err != nil { - return err - } - // new snapshot - snap = newSnapshot(p.config, p.signatures, number, hash, validators, voteAddrs) - p.recentSnaps.Add(hash, snap) - if err := snap.store(p.db); err != nil { - return err - } - p.logger.Info("Stored checkpoint snapshot to disk", "number", number, "hash", hash) - } else { - snap, ok := p.recentSnaps.Get(header.ParentHash) - if !ok { - return fmt.Errorf("can't found parent Snap, number = %d", number) + for snap == nil { + if s, ok := p.recentSnaps.Get(hash); ok { + snap = s + break } - headers = append(headers, header) - if _, err := snap.apply(headers, chain, nil, p.chainConfig, p.recentSnaps); err != nil { - return err + + // If an on-disk checkpoint snapshot can be found, use that + if number%CheckpointInterval == 0 { + if s, err := loadSnapshot(p.config, p.signatures, p.db, number, hash); err == nil { + log.Trace("Loaded snapshot from disk", "number", number, "hash", hash) + snap = s + break + } } - // If we've generated a new checkpoint snapshot, save to disk - if snap.Number%CheckpointInterval == 0 { + + // If we're at the genesis, snapshot the initial state. 
+ if number == 0 { + // Headers included into the snapshots have to be trusted as checkpoints get validators from headers + validators, voteAddrs, err := parseValidators(header, p.chainConfig, p.config) + if err != nil { + return err + } + // new snapshot + snap = newSnapshot(p.config, p.signatures, number, hash, validators, voteAddrs) + p.recentSnaps.Add(hash, snap) if err := snap.store(p.db); err != nil { return err } - p.logger.Trace("Stored snapshot to disk", "number", snap.Number, "hash", snap.Hash) + p.logger.Info("Stored checkpoint snapshot to disk", "number", number, "hash", hash) + break } + headers = append(headers, header) + number, hash = number-1, header.ParentHash + header = chain.GetHeader(hash, number) } - return nil + + // check if snapshot is nil + if snap == nil { + return fmt.Errorf("unknown error while retrieving snapshot at block number %v", number) + } + + // Previous snapshot found, apply any pending headers on top of it + for i := 0; i < len(headers)/2; i++ { + headers[i], headers[len(headers)-1-i] = headers[len(headers)-1-i], headers[i] + } + snap, err := snap.apply(headers, chain, nil, p.chainConfig, p.recentSnaps) + if err != nil { + return err + } + if snap.Number%CheckpointInterval == 0 && len(headers) > 0 { + if err = snap.store(p.db); err != nil { + return err + } + log.Trace("Stored snapshot to disk", "number", snap.Number, "hash", snap.Hash) + } + return err } diff --git a/consensus/parlia/snapshot.go b/consensus/parlia/snapshot.go index d10c38093d2..5892e2e0889 100644 --- a/consensus/parlia/snapshot.go +++ b/consensus/parlia/snapshot.go @@ -19,6 +19,7 @@ package parlia import ( "bytes" "context" + "encoding/binary" "encoding/hex" "encoding/json" "errors" @@ -58,6 +59,8 @@ type ValidatorInfo struct { VoteAddress types.BLSPublicKey `json:"vote_address,omitempty"` } +const lastSnapshot = "snap" + // newSnapshot creates a new snapshot with the specified startup parameters. 
This // method does not initialize the set of recent validators, so only ever use it for // the genesis block. @@ -113,7 +116,7 @@ func SnapshotFullKey(number uint64, hash common.Hash) []byte { return append(hexutility.EncodeTs(number), hash.Bytes()...) } -var ErrNoSnapsnot = fmt.Errorf("no parlia snapshot") +var ErrNoSnapsnot = errors.New("no parlia snapshot") // loadSnapshot loads an existing snapshot from the database. func loadSnapshot(config *chain.ParliaConfig, sigCache *lru.ARCCache[common.Hash, common.Address], db kv.RwDB, num uint64, hash common.Hash) (*Snapshot, error) { @@ -144,6 +147,23 @@ func loadSnapshot(config *chain.ParliaConfig, sigCache *lru.ARCCache[common.Hash return snap, nil } +// getLatest returns the block number of the most recently stored snapshot. +func getLatest(db kv.RwDB) (uint64, error) { + tx, err := db.BeginRo(context.Background()) + if err != nil { + return 0, err + } + defer tx.Rollback() + data, err := tx.GetOne(kv.ParliaSnapshot, []byte(lastSnapshot)) + if err != nil { + return 0, err + } + if len(data) < 8 { + return 0, nil + } + return binary.BigEndian.Uint64(data[:8]), nil +} + // store inserts the snapshot into the database. 
func (s *Snapshot) store(db kv.RwDB) error { blob, err := json.Marshal(s) @@ -151,7 +171,13 @@ func (s *Snapshot) store(db kv.RwDB) error { return err } return db.UpdateNosync(context.Background(), func(tx kv.RwTx) error { - return tx.Put(kv.ParliaSnapshot, SnapshotFullKey(s.Number, s.Hash), blob) + if err = tx.Put(kv.ParliaSnapshot, SnapshotFullKey(s.Number, s.Hash), blob); err != nil { + return err + } + if err = tx.Put(kv.ParliaSnapshot, []byte(lastSnapshot), hexutility.EncodeTs(s.Number)); err != nil { + return err + } + return err }) } diff --git a/eth/stagedsync/stage_snapshots.go b/eth/stagedsync/stage_snapshots.go index d0dfc2c57d8..9df84eaff25 100644 --- a/eth/stagedsync/stage_snapshots.go +++ b/eth/stagedsync/stage_snapshots.go @@ -385,6 +385,7 @@ func FillDBFromSnapshots(logPrefix string, ctx context.Context, tx kv.RwTx, dirs td := big.NewInt(0) blockNumBytes := make([]byte, 8) posa, isPoSa := engine.(consensus.PoSA) + var bscProgress uint64 chainReader := &ChainReaderImpl{config: &chainConfig, tx: tx, blockReader: blockReader} if err := blockReader.HeadersRange(ctx, func(header *types.Header) error { blockNum, blockHash := header.Number.Uint64(), header.Hash() @@ -416,10 +417,16 @@ func FillDBFromSnapshots(logPrefix string, ctx context.Context, tx kv.RwTx, dirs } } if isPoSa { - // Fill bsc consensus snapshots may have some conditions for validators snapshots - if err := posa.ResetSnapshot(chainReader, header); err != nil { + bscProgress, err = posa.GetBscProgress() + if err != nil { return err } + if blockNum > bscProgress { + // Fill bsc consensus snapshots may have some conditions for validators snapshots + if err := posa.ResetSnapshot(chainReader, header); err != nil { + return err + } + } } select { case <-ctx.Done():