diff --git a/eth/api_simulation.go b/eth/api_simulation.go index b73048aaec24..d86b56594d4d 100644 --- a/eth/api_simulation.go +++ b/eth/api_simulation.go @@ -5,13 +5,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/ethereum/go-ethereum/core/txpool" + "github.com/ethereum/go-ethereum/common" "math" "strings" "sync" "sync/atomic" "time" + "github.com/ethereum/go-ethereum/core/txpool" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" @@ -27,15 +29,23 @@ type TraceInternalTransactionArgs struct { Tx hexutil.Bytes `json:"tx"` } -type TransactionInternalTransactionsByBundleArgs struct { - Txs []hexutil.Bytes `json:"txs"` -} - type Backend interface { BlockChain() *core.BlockChain TxPool() *txpool.TxPool } +// list is a "list" of the statedb belonging to an account, sorted by account nonce +type list struct { + snapshots map[uint64]*state.StateDB + mu sync.Mutex +} + +func (l *list) findSnapshotByNonce(nonce uint64) *state.StateDB { + l.mu.Lock() + defer l.mu.Unlock() + return l.snapshots[nonce] +} + // SimulationAPIBackend creates a new simulation API type SimulationAPIBackend struct { eth Backend @@ -53,6 +63,8 @@ type SimulationAPIBackend struct { stateDb *state.StateDB // current stateDb of the blockchain currentSigner types.Signer // current signer according to the current block currentBlockCtx vm.BlockContext // current block context according to the current block + + stateDbCheckpoint sync.Map // store "list" of checkpoint belonging to an account address } func NewSimulationAPI(eth Backend) *SimulationAPIBackend { @@ -103,42 +115,6 @@ func (b *SimulationAPIBackend) TraceInternalTransaction(ctx context.Context, arg return simulationResponse, nil } -func (b *SimulationAPIBackend) TraceInternalTransactionsByBundle(ctx context.Context, args TransactionInternalTransactionsByBundleArgs) ([]*types.SimulationTxResponse, error) { - if len(args.Txs) == 0 { - return nil, errors.New("missing transaction") - } - - if isCatchUpLatestBlock := b.isCatchUpLatestBlock.Load(); !isCatchUpLatestBlock { - var blockNumber uint64 - if b.currentBlock != nil { - blockNumber = b.currentBlock.NumberU64() - } - return nil, fmt.Errorf("the state isn't up to date, block_number: %d", blockNumber) - } - - var ( - currentBlock = b.currentBlock - stateDb = b.stateDb - simulationBundleResponse = make([]*types.SimulationTxResponse, 0) - ) - - for _, binaryTx := range args.Txs { - tx := new(types.Transaction) - if err := tx.UnmarshalBinary(binaryTx); err != nil { - return nil, err - } - - simulationResponse, err := b.simulate(tx, stateDb.Copy(), currentBlock) - if err != nil { - return nil, err - } - - simulationBundleResponse = append(simulationBundleResponse, simulationResponse) - } - - return simulationBundleResponse, nil -} - func (b *SimulationAPIBackend) Stop() { b.chainHeadSub.Unsubscribe() close(b.exitCh) @@ -174,6 +150,10 @@ func (b *SimulationAPIBackend) loop() error { b.currentBlockCtx = blockCtx b.currentSigner = signer b.isCatchUpLatestBlock.Store(true) + + // clear the checkpoint of snapshots if the states are stale + b.clearStaleSnapshots(readOnlyStateDb.Copy()) + case err := <-b.chainHeadSub.Err(): return err case <-b.exitCh: @@ -213,6 +193,18 @@ func (b *SimulationAPIBackend) simulate(tx *types.Transaction, stateDb *state.St TxHash: tx.Hash(), } ) + // load the checkpoint db if exists + var currentList *list + if enc, found := b.stateDbCheckpoint.Load(msg.From.Hex()); found && enc != nil { + list, ok := enc.(*list) + if ok && list != nil { + currentList = list + checkpointStateDb := list.findSnapshotByNonce(msg.Nonce) + if checkpointStateDb != nil { + stateDb = checkpointStateDb.Copy() + } + } + } internalTransactionTracer, err := tracers.DefaultDirectory.New(native.InternalTransactionTracerName, tracerCtx, json.RawMessage{}) if err != nil { @@ -232,6 +224,7 @@ func (b *SimulationAPIBackend) simulate(tx *types.Transaction, stateDb *state.St log.Error("Failed to apply the message", "hash", tx.Hash().String(), "number", currentBlock.NumberU64(), "err", err) return nil, err } + b.storeSnapshot(stateDb, currentList, msg.Nonce, msg.From) if executionResult == nil { log.Warn("Simulation result is empty", "tx_hash", tx.Hash().String()) @@ -311,3 +304,41 @@ func (b *SimulationAPIBackend) isLatestBlock(blockTime int64) bool { } return false } + +func (b *SimulationAPIBackend) storeSnapshot(stateDb *state.StateDB, l *list, nonce uint64, from common.Address) { + if l == nil { + l = &list{ + snapshots: make(map[uint64]*state.StateDB), + } + } + nextNonce := nonce + 1 + l.snapshots[nextNonce] = stateDb + b.stateDbCheckpoint.Store(from.Hex(), l) +} + +func (b *SimulationAPIBackend) clearStaleSnapshots(stateDb *state.StateDB) { + b.stateDbCheckpoint.Range(func(k, v any) bool { + l := v.(*list) + + address := k.(string) + + if l == nil { + b.stateDbCheckpoint.Delete(address) + return true + } + + pendingNonce := stateDb.GetNonce(common.HexToAddress(address)) + + for nonce := range l.snapshots { + if pendingNonce > nonce-1 { + delete(l.snapshots, nonce) + } + } + + if len(l.snapshots) == 0 { + b.stateDbCheckpoint.Delete(address) + return true + } + return true + }) +}