From 64e52cf9388e179fe12a2892b8284be3e9ce7647 Mon Sep 17 00:00:00 2001 From: Jerry Date: Thu, 29 Aug 2024 08:56:35 -0700 Subject: [PATCH 1/2] zkevm_getProof (#1014) * feat: adding an implementation for eth_getProof This is just a proof of concept. This might not even make sense and probably don't belong in this particular file, but wanted to see if we could expose an implementation of [eip-1186](https://eips.ethereum.org/EIPS/eip-1186) for the SMT + poseidon. * SMT Proof + verification * Refactor and tests * Use finish stage to get latest block The latest block should be validated with a state root hash check before becoming the "latest" block. --------- Co-authored-by: John Hilliard Co-authored-by: Valentin Staykov <79150443+V-Staykov@users.noreply.github.com> --- cmd/rpcdaemon/commands/zkevm_api.go | 191 ++++++++++++++++++++++ core/types/accounts/account_proof.go | 19 +++ smt/pkg/smt/proof.go | 207 +++++++++++++++++++++++ smt/pkg/smt/proof_test.go | 236 +++++++++++++++++++++++++++ turbo/rpchelper/rpc_block.go | 2 +- 5 files changed, 654 insertions(+), 1 deletion(-) create mode 100644 smt/pkg/smt/proof.go create mode 100644 smt/pkg/smt/proof_test.go diff --git a/cmd/rpcdaemon/commands/zkevm_api.go b/cmd/rpcdaemon/commands/zkevm_api.go index e38d2fe63f7..c1a38ded839 100644 --- a/cmd/rpcdaemon/commands/zkevm_api.go +++ b/cmd/rpcdaemon/commands/zkevm_api.go @@ -11,25 +11,34 @@ import ( libcommon "github.com/gateway-fm/cdk-erigon-lib/common" "github.com/gateway-fm/cdk-erigon-lib/common/hexutility" "github.com/gateway-fm/cdk-erigon-lib/kv" + "github.com/gateway-fm/cdk-erigon-lib/kv/memdb" jsoniter "github.com/json-iterator/go" "github.com/holiman/uint256" "github.com/ledgerwatch/erigon/common/hexutil" "github.com/ledgerwatch/erigon/core" "github.com/ledgerwatch/erigon/core/rawdb" + "github.com/ledgerwatch/erigon/core/state" eritypes "github.com/ledgerwatch/erigon/core/types" + "github.com/ledgerwatch/erigon/core/types/accounts" "github.com/ledgerwatch/erigon/eth/ethconfig" + "github.com/ledgerwatch/erigon/eth/stagedsync" "github.com/ledgerwatch/erigon/eth/stagedsync/stages" "github.com/ledgerwatch/erigon/eth/tracers" "github.com/ledgerwatch/erigon/rpc" + smtDb "github.com/ledgerwatch/erigon/smt/pkg/db" + smt "github.com/ledgerwatch/erigon/smt/pkg/smt" + smtUtils "github.com/ledgerwatch/erigon/smt/pkg/utils" "github.com/ledgerwatch/erigon/turbo/rpchelper" "github.com/ledgerwatch/erigon/zk/hermez_db" "github.com/ledgerwatch/erigon/zk/legacy_executor_verifier" types "github.com/ledgerwatch/erigon/zk/rpcdaemon" "github.com/ledgerwatch/erigon/zk/sequencer" + zkStages "github.com/ledgerwatch/erigon/zk/stages" "github.com/ledgerwatch/erigon/zk/syncer" zktx "github.com/ledgerwatch/erigon/zk/tx" "github.com/ledgerwatch/erigon/zk/utils" + zkUtils "github.com/ledgerwatch/erigon/zk/utils" "github.com/ledgerwatch/erigon/zk/witness" "github.com/ledgerwatch/erigon/zkevm/hex" "github.com/ledgerwatch/erigon/zkevm/jsonrpc/client" @@ -1444,3 +1453,185 @@ func populateBatchDataSlimDetails(batches []*types.BatchDataSlim) (json.RawMessa return json.Marshal(jBatches) } + +// GetProof +func (zkapi *ZkEvmAPIImpl) GetProof(ctx context.Context, address common.Address, storageKeys []common.Hash, blockNrOrHash rpc.BlockNumberOrHash) (*accounts.SMTAccProofResult, error) { + api := zkapi.ethApi + + tx, err := api.db.BeginRo(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback() + if api.historyV3(tx) { + return nil, fmt.Errorf("not supported by Erigon3") + } + + blockNr, _, _, err := rpchelper.GetBlockNumber(blockNrOrHash, tx, api.filters) + if err != nil { + return nil, err + } + + latestBlock, err := rpchelper.GetLatestBlockNumber(tx) + if err != nil { + return nil, err + } + + if latestBlock < blockNr { + // shouldn't happen, but check anyway + return nil, fmt.Errorf("block number is in the future latest=%d requested=%d", latestBlock, blockNr) + } + + batch := memdb.NewMemoryBatch(tx, api.dirs.Tmp) + defer batch.Rollback() + if err = zkUtils.PopulateMemoryMutationTables(batch); err != nil { + return nil, err + } + + if blockNr < latestBlock { + if latestBlock-blockNr > maxGetProofRewindBlockCount { + return nil, fmt.Errorf("requested block is too old, block must be within %d blocks of the head block number (currently %d)", maxGetProofRewindBlockCount, latestBlock) + } + unwindState := &stagedsync.UnwindState{UnwindPoint: blockNr} + stageState := &stagedsync.StageState{BlockNumber: latestBlock} + + interHashStageCfg := zkStages.StageZkInterHashesCfg(nil, true, true, false, api.dirs.Tmp, api._blockReader, nil, api.historyV3(tx), api._agg, nil) + + if err = zkStages.UnwindZkIntermediateHashesStage(unwindState, stageState, batch, interHashStageCfg, ctx, true); err != nil { + return nil, fmt.Errorf("unwind intermediate hashes: %w", err) + } + + if err != nil { + return nil, err + } + tx = batch + } + + reader, err := rpchelper.CreateStateReader(ctx, tx, blockNrOrHash, 0, api.filters, api.stateCache, api.historyV3(tx), "") + if err != nil { + return nil, err + } + + header, err := api._blockReader.HeaderByNumber(ctx, tx, blockNr) + if err != nil { + return nil, err + } + + tds := state.NewTrieDbState(header.Root, tx, blockNr, nil) + tds.SetResolveReads(true) + tds.StartNewBuffer() + tds.SetStateReader(reader) + + ibs := state.New(tds) + + ibs.GetBalance(address) + + for _, key := range storageKeys { + value := new(uint256.Int) + ibs.GetState(address, &key, value) + } + + rl, err := tds.ResolveSMTRetainList() + if err != nil { + return nil, err + } + + smtTrie := smt.NewRoSMT(smtDb.NewRoEriDb(tx)) + + proofs, err := smt.BuildProofs(smtTrie, rl, ctx) + if err != nil { + return nil, err + } + + stateRootNode := smtUtils.ScalarToRoot(new(big.Int).SetBytes(header.Root.Bytes())) + + if err != nil { + return nil, err + } + + balanceKey, err := smtUtils.KeyEthAddrBalance(address.String()) + if err != nil { + return nil, err + } + + nonceKey, err := smtUtils.KeyEthAddrNonce(address.String()) + if err != nil { + return nil, err + } + + codeHashKey, err := smtUtils.KeyContractCode(address.String()) + if err != nil { + return nil, err + } + + codeLengthKey, err := smtUtils.KeyContractLength(address.String()) + if err != nil { + return nil, err + } + + balanceProofs := smt.FilterProofs(proofs, balanceKey) + balanceBytes, err := smt.VerifyAndGetVal(stateRootNode, balanceProofs, balanceKey) + if err != nil { + return nil, fmt.Errorf("balance proof verification failed: %w", err) + } + + balance := new(big.Int).SetBytes(balanceBytes) + + nonceProofs := smt.FilterProofs(proofs, nonceKey) + nonceBytes, err := smt.VerifyAndGetVal(stateRootNode, nonceProofs, nonceKey) + if err != nil { + return nil, fmt.Errorf("nonce proof verification failed: %w", err) + } + nonce := new(big.Int).SetBytes(nonceBytes).Uint64() + + codeHashProofs := smt.FilterProofs(proofs, codeHashKey) + codeHashBytes, err := smt.VerifyAndGetVal(stateRootNode, codeHashProofs, codeHashKey) + if err != nil { + return nil, fmt.Errorf("code hash proof verification failed: %w", err) + } + codeHash := codeHashBytes + + codeLengthProofs := smt.FilterProofs(proofs, codeLengthKey) + codeLengthBytes, err := smt.VerifyAndGetVal(stateRootNode, codeLengthProofs, codeLengthKey) + if err != nil { + return nil, fmt.Errorf("code length proof verification failed: %w", err) + } + codeLength := new(big.Int).SetBytes(codeLengthBytes).Uint64() + + accProof := &accounts.SMTAccProofResult{ + Address: address, + Balance: (*hexutil.Big)(balance), + CodeHash: libcommon.BytesToHash(codeHash), + CodeLength: hexutil.Uint64(codeLength), + Nonce: hexutil.Uint64(nonce), + BalanceProof: balanceProofs, + NonceProof: nonceProofs, + CodeHashProof: codeHashProofs, + CodeLengthProof: codeLengthProofs, + StorageProof: make([]accounts.SMTStorageProofResult, 0), + } + + addressArrayBig := smtUtils.ScalarToArrayBig(smtUtils.ConvertHexToBigInt(address.String())) + for _, k := range storageKeys { + storageKey, err := smtUtils.KeyContractStorage(addressArrayBig, k.String()) + if err != nil { + return nil, err + } + storageProofs := smt.FilterProofs(proofs, storageKey) + + valueBytes, err := smt.VerifyAndGetVal(stateRootNode, storageProofs, storageKey) + if err != nil { + return nil, fmt.Errorf("storage proof verification failed: %w", err) + } + + value := new(big.Int).SetBytes(valueBytes) + + accProof.StorageProof = append(accProof.StorageProof, accounts.SMTStorageProofResult{ + Key: k, + Value: (*hexutil.Big)(value), + Proof: storageProofs, + }) + } + + return accProof, nil +} diff --git a/core/types/accounts/account_proof.go b/core/types/accounts/account_proof.go index 4041353ae46..5061f4be438 100644 --- a/core/types/accounts/account_proof.go +++ b/core/types/accounts/account_proof.go @@ -22,3 +22,22 @@ type StorProofResult struct { Value *hexutil.Big `json:"value"` Proof []hexutility.Bytes `json:"proof"` } + +type SMTAccProofResult struct { + Address libcommon.Address `json:"address"` + Balance *hexutil.Big `json:"balance"` + CodeHash libcommon.Hash `json:"codeHash"` + CodeLength hexutil.Uint64 `json:"codeLength"` + Nonce hexutil.Uint64 `json:"nonce"` + BalanceProof []hexutility.Bytes `json:"balanceProof"` + NonceProof []hexutility.Bytes `json:"nonceProof"` + CodeHashProof []hexutility.Bytes `json:"codeHashProof"` + CodeLengthProof []hexutility.Bytes `json:"codeLengthProof"` + StorageProof []SMTStorageProofResult `json:"storageProof"` +} + +type SMTStorageProofResult struct { + Key libcommon.Hash `json:"key"` + Value *hexutil.Big `json:"value"` + Proof []hexutility.Bytes `json:"proof"` +} diff --git a/smt/pkg/smt/proof.go b/smt/pkg/smt/proof.go new file mode 100644 index 00000000000..2e225be666b --- /dev/null +++ b/smt/pkg/smt/proof.go @@ -0,0 +1,207 @@ +package smt + +import ( + "bytes" + "context" + "fmt" + "math/big" + + "github.com/gateway-fm/cdk-erigon-lib/common/hexutility" + "github.com/ledgerwatch/erigon/smt/pkg/utils" + "github.com/ledgerwatch/erigon/turbo/trie" +) + +type SMTProofElement struct { + Path []byte + Proof []byte +} + +// FilterProofs filters the proofs to only include the ones that match the given key +func FilterProofs(proofs []*SMTProofElement, key utils.NodeKey) []hexutility.Bytes { + filteredProofs := make([]hexutility.Bytes, 0) + keyPath := key.GetPath() + + keyPathInBytes := make([]byte, len(keyPath)) + for i, v := range keyPath { + keyPathInBytes[i] = byte(v) + } + + for _, proof := range proofs { + if bytes.HasPrefix(keyPathInBytes, proof.Path) { + proofClone := make([]byte, len(proof.Proof)) + copy(proofClone, proof.Proof) + filteredProofs = append(filteredProofs, proofClone) + } + } + + return filteredProofs +} + +// BuildProofs builds proofs for multiple accounts and storage slots by traversing the SMT once. +// It efficiently generates proofs for all requested keys in a single pass. +// +// s: The read-only SMT to traverse +// rd: The retain decider that determines which nodes to include in the proof +// ctx: Context for cancellation +// +// Returns a slice of SMTProofElement containing the proof for each retained node, +// or an error if the traversal fails. +func BuildProofs(s *RoSMT, rd trie.RetainDecider, ctx context.Context) ([]*SMTProofElement, error) { + proofs := make([]*SMTProofElement, 0) + + root, err := s.DbRo.GetLastRoot() + if err != nil { + return nil, err + } + + action := func(prefix []byte, k utils.NodeKey, v utils.NodeValue12) (bool, error) { + retain := rd.Retain(prefix) + + if !retain { + return false, nil + } + + nodeBytes := make([]byte, 64) + utils.ArrayToScalar(v.Get0to4()[:]).FillBytes(nodeBytes[:32]) + utils.ArrayToScalar(v.Get4to8()[:]).FillBytes(nodeBytes[32:]) + + if v.IsFinalNode() { + nodeBytes = append(nodeBytes, 1) + } + + proofs = append(proofs, &SMTProofElement{ + Path: prefix, + Proof: nodeBytes, + }) + + if v.IsFinalNode() { + valHash := v.Get4to8() + v, err := s.DbRo.Get(*valHash) + if err != nil { + return false, err + } + + vInBytes := utils.ArrayBigToScalar(utils.BigIntArrayFromNodeValue8(v.GetNodeValue8())).Bytes() + + proofs = append(proofs, &SMTProofElement{ + Path: prefix, + Proof: vInBytes, + }) + + return false, nil + } + + return true, nil + } + + err = s.Traverse(ctx, root, action) + if err != nil { + return nil, err + } + + return proofs, nil +} + +// VerifyAndGetVal verifies a proof against a given state root and key, and returns the associated value if valid. +// +// Parameters: +// - stateRoot: The root node key to verify the proof against. +// - proof: A slice of byte slices representing the proof elements. +// - key: The node key for which the proof is being verified. +// +// Returns: +// - []byte: The value associated with the key. If the key does not exist in the proof, the value returned will be nil. +// - error: An error if the proof is invalid or verification fails. +// +// This function walks through the provided proof, verifying each step against the expected +// state root. It handles both branch and leaf nodes in the Sparse Merkle Tree. If the proof +// is valid and and value exists, it returns the value associated with the given key. If the proof is valid and +// the value does not exist, the value returned will be nil. If the proof is invalid at any point, an error is returned explaining where the verification failed. +// +// The function expects the proof to be in a specific format, with each element being either +// 64 bytes (for branch nodes) or 65 bytes (for leaf nodes, with the last byte indicating finality). +// It uses the utils package for various operations like hashing and key manipulation. +func VerifyAndGetVal(stateRoot utils.NodeKey, proof []hexutility.Bytes, key utils.NodeKey) ([]byte, error) { + if len(proof) == 0 { + return nil, fmt.Errorf("proof is empty") + } + + path := key.GetPath() + + curRoot := stateRoot + + foundValue := false + + for i := 0; i < len(proof); i++ { + isFinalNode := len(proof[i]) == 65 + + capacity := utils.BranchCapacity + + if isFinalNode { + capacity = utils.LeafCapacity + } + + leftChild := utils.ScalarToArray(big.NewInt(0).SetBytes(proof[i][:32])) + rightChild := utils.ScalarToArray(big.NewInt(0).SetBytes(proof[i][32:64])) + + leftChildNode := [4]uint64{leftChild[0], leftChild[1], leftChild[2], leftChild[3]} + rightChildNode := [4]uint64{rightChild[0], rightChild[1], rightChild[2], rightChild[3]} + + h, err := utils.Hash(utils.ConcatArrays4(leftChildNode, rightChildNode), capacity) + + if err != nil { + return nil, err + } + + if curRoot != h { + return nil, fmt.Errorf("root mismatch at level %d, expected %d, got %d", i, curRoot, h) + } + + if !isFinalNode { + if path[i] == 0 { + curRoot = leftChildNode + } else { + curRoot = rightChildNode + } + + // If the current root is zero, non-existence has been proven and we can return nil from here + if curRoot.IsZero() { + return nil, nil + } + } else { + joinedKey := utils.JoinKey(path[:i], leftChildNode) + if joinedKey.IsEqualTo(key) { + foundValue = true + curRoot = rightChildNode + break + } else { + // If the joined key is not equal to the input key, the proof is sufficient to verify the non-existence of the value, so we return nil from here + return nil, nil + } + } + } + + // If we've made it through the loop without finding the value, the proof is insufficient to verify the non-existence of the value + if !foundValue { + return nil, fmt.Errorf("proof is insufficient to verify the non-existence of the value") + } + + v := new(big.Int).SetBytes(proof[len(proof)-1]) + x := utils.ScalarToArrayBig(v) + nodeValue, err := utils.NodeValue8FromBigIntArray(x) + if err != nil { + return nil, err + } + + h, err := utils.Hash(nodeValue.ToUintArray(), utils.BranchCapacity) + + if err != nil { + return nil, err + } + + if h != curRoot { + return nil, fmt.Errorf("root mismatch at level %d, expected %d, got %d", len(proof)-1, curRoot, h) + } + + return proof[len(proof)-1], nil +} diff --git a/smt/pkg/smt/proof_test.go b/smt/pkg/smt/proof_test.go new file mode 100644 index 00000000000..f6d92be48a1 --- /dev/null +++ b/smt/pkg/smt/proof_test.go @@ -0,0 +1,236 @@ +package smt_test + +import ( + "bytes" + "context" + "fmt" + "reflect" + "strings" + "testing" + + libcommon "github.com/gateway-fm/cdk-erigon-lib/common" + "github.com/gateway-fm/cdk-erigon-lib/common/hexutility" + "github.com/holiman/uint256" + "github.com/ledgerwatch/erigon/smt/pkg/smt" + "github.com/ledgerwatch/erigon/smt/pkg/utils" + "github.com/ledgerwatch/erigon/turbo/trie" +) + +func TestFilterProofs(t *testing.T) { + tests := []struct { + name string + proofs []*smt.SMTProofElement + key utils.NodeKey + expected []hexutility.Bytes + }{ + { + name: "Matching proofs", + proofs: []*smt.SMTProofElement{ + {Path: []byte{0, 1}, Proof: []byte{1, 2, 3}}, + {Path: []byte{0, 1, 1}, Proof: []byte{4, 5, 6}}, + {Path: []byte{1, 1}, Proof: []byte{7, 8, 9}}, + }, + key: utils.NodeKey{0, 1, 1, 1}, + expected: []hexutility.Bytes{{1, 2, 3}, {4, 5, 6}}, + }, + { + name: "No matching proofs", + proofs: []*smt.SMTProofElement{ + {Path: []byte{1, 1}, Proof: []byte{1, 2, 3}}, + {Path: []byte{1, 0}, Proof: []byte{4, 5, 6}}, + }, + key: utils.NodeKey{0, 1, 1, 1}, + expected: []hexutility.Bytes{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := smt.FilterProofs(tt.proofs, tt.key) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("FilterProofs() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestVerifyAndGetVal(t *testing.T) { + smtTrie, rl := prepareSMT(t) + + proofs, err := smt.BuildProofs(smtTrie.RoSMT, rl, context.Background()) + if err != nil { + t.Fatalf("BuildProofs() error = %v", err) + } + + contractAddress := libcommon.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624") + a := utils.ConvertHexToBigInt(contractAddress.String()) + address := utils.ScalarToArrayBig(a) + + smtRoot, _ := smtTrie.RoSMT.DbRo.GetLastRoot() + if err != nil { + t.Fatalf("GetLastRoot() error = %v", err) + } + root := utils.ScalarToRoot(smtRoot) + + t.Run("Value exists and proof is correct", func(t *testing.T) { + storageKey, err := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String()) + + if err != nil { + t.Fatalf("KeyContractStorage() error = %v", err) + } + + storageProof := smt.FilterProofs(proofs, storageKey) + + val, err := smt.VerifyAndGetVal(root, storageProof, storageKey) + + if err != nil { + t.Fatalf("VerifyAndGetVal() error = %v", err) + } + + expected := uint256.NewInt(0xdeadbeef).Bytes() + + if !bytes.Equal(val, expected) { + t.Errorf("VerifyAndGetVal() = %v, want %v", val, expected) + } + }) + + t.Run("Value doesn't exist and non-existent proof is correct", func(t *testing.T) { + nonExistentRl := trie.NewRetainList(0) + nonExistentKeys := []utils.NodeKey{} + + // Fuzz with 1000 non-existent keys + for i := 0; i < 1000; i++ { + nonExistentKey, err := utils.KeyContractStorage( + address, + libcommon.HexToHash(fmt.Sprintf("0xdeadbeefabcd1234%d", i)).String(), + ) + + nonExistentKeys = append(nonExistentKeys, nonExistentKey) + + if err != nil { + t.Fatalf("KeyContractStorage() error = %v", err) + } + + nonExistentKeyPath := nonExistentKey.GetPath() + + keyBytes := make([]byte, 0, len(nonExistentKeyPath)) + + for _, v := range nonExistentKeyPath { + keyBytes = append(keyBytes, byte(v)) + } + + nonExistentRl.AddHex(keyBytes) + } + + nonExistentProofs, err := smt.BuildProofs(smtTrie.RoSMT, nonExistentRl, context.Background()) + if err != nil { + t.Fatalf("BuildProofs() error = %v", err) + } + + for _, key := range nonExistentKeys { + nonExistentProof := smt.FilterProofs(nonExistentProofs, key) + val, err := smt.VerifyAndGetVal(root, nonExistentProof, key) + + if err != nil { + t.Fatalf("VerifyAndGetVal() error = %v", err) + } + + if len(val) != 0 { + t.Errorf("VerifyAndGetVal() = %v, want empty value", val) + } + } + }) + + t.Run("Value doesn't exist but non-existent proof is insufficient", func(t *testing.T) { + nonExistentRl := trie.NewRetainList(0) + nonExistentKey, _ := utils.KeyContractStorage(address, libcommon.HexToHash("0x999").String()) + nonExistentKeyPath := nonExistentKey.GetPath() + keyBytes := make([]byte, 0, len(nonExistentKeyPath)) + + for _, v := range nonExistentKeyPath { + keyBytes = append(keyBytes, byte(v)) + } + + nonExistentRl.AddHex(keyBytes) + + nonExistentProofs, err := smt.BuildProofs(smtTrie.RoSMT, nonExistentRl, context.Background()) + if err != nil { + t.Fatalf("BuildProofs() error = %v", err) + } + + nonExistentProof := smt.FilterProofs(nonExistentProofs, nonExistentKey) + + // Verify the non-existent proof works + _, err = smt.VerifyAndGetVal(root, nonExistentProof, nonExistentKey) + + if err != nil { + t.Fatalf("VerifyAndGetVal() error = %v", err) + } + + // Only pass the first trie node in the proof + _, err = smt.VerifyAndGetVal(root, nonExistentProof[:1], nonExistentKey) + + if err == nil { + t.Errorf("VerifyAndGetVal() expected error, got nil") + } + }) + + t.Run("Value exists but proof is incorrect (first value corrupted)", func(t *testing.T) { + storageKey, _ := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String()) + storageProof := smt.FilterProofs(proofs, storageKey) + + // Corrupt the proof by changing a byte + if len(storageProof) > 0 && len(storageProof[0]) > 0 { + storageProof[0][0] ^= 0xFF // Flip all bits in the first byte + } + + _, err := smt.VerifyAndGetVal(root, storageProof, storageKey) + + if err == nil { + if err == nil || !strings.Contains(err.Error(), "root mismatch at level 0") { + t.Errorf("VerifyAndGetVal() expected error containing 'root mismatch at level 0', got %v", err) + } + } + }) + + t.Run("Value exists but proof is incorrect (last value corrupted)", func(t *testing.T) { + storageKey, _ := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String()) + storageProof := smt.FilterProofs(proofs, storageKey) + + // Corrupt the proof by changing the last byte of the last proof element + if len(storageProof) > 0 { + lastProof := storageProof[len(storageProof)-1] + if len(lastProof) > 0 { + lastProof[len(lastProof)-1] ^= 0xFF // Flip all bits in the last byte + } + } + + _, err := smt.VerifyAndGetVal(root, storageProof, storageKey) + + if err == nil { + if err == nil || !strings.Contains(err.Error(), fmt.Sprintf("root mismatch at level %d", len(storageProof)-1)) { + t.Errorf("VerifyAndGetVal() expected error containing 'root mismatch at level %d', got %v", len(storageProof)-1, err) + } + } + }) + + t.Run("Value exists but proof is insufficient", func(t *testing.T) { + storageKey, _ := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String()) + storageProof := smt.FilterProofs(proofs, storageKey) + + // Modify the proof to claim the value doesn't exist + if len(storageProof) > 0 { + storageProof = storageProof[:len(storageProof)-2] + } + + val, err := smt.VerifyAndGetVal(root, storageProof, storageKey) + + if err == nil || !strings.Contains(err.Error(), "insufficient") { + t.Errorf("VerifyAndGetVal() expected error containing 'insufficient', got %v", err) + } + + if len(val) != 0 { + t.Errorf("VerifyAndGetVal() = %v, want empty value", val) + } + }) +} diff --git a/turbo/rpchelper/rpc_block.go b/turbo/rpchelper/rpc_block.go index f0fa1b04340..90d7e0c40ba 100644 --- a/turbo/rpchelper/rpc_block.go +++ b/turbo/rpchelper/rpc_block.go @@ -26,7 +26,7 @@ func GetLatestBlockNumber(tx kv.Tx) (uint64, error) { } } - blockNum, err := stages.GetStageProgress(tx, stages.Execution) + blockNum, err := stages.GetStageProgress(tx, stages.Finish) if err != nil { return 0, fmt.Errorf("getting latest block number: %w", err) } From 31dc6f6992058d9e6833af978860bd838b235823 Mon Sep 17 00:00:00 2001 From: Jerry Date: Thu, 29 Aug 2024 09:54:51 -0700 Subject: [PATCH 2/2] Delete rolled back L1 sequence on sequence rollback event (#1022) * Delete rolled back L1 sequence info on sequence rollback event * Address CR comments --- eth/backend.go | 1 + zk/contracts/l1_contracts.go | 1 + zk/hermez_db/db.go | 21 +++++++++++++++++++++ zk/stages/stage_l1syncer.go | 11 +++++++++++ 4 files changed, 34 insertions(+) diff --git a/eth/backend.go b/eth/backend.go index aad639d4c4c..6dbfd0c6506 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -799,6 +799,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) { seqAndVerifTopics := [][]libcommon.Hash{{ contracts.SequencedBatchTopicPreEtrog, contracts.SequencedBatchTopicEtrog, + contracts.RollbackBatchesTopic, contracts.VerificationTopicPreEtrog, contracts.VerificationTopicEtrog, contracts.VerificationValidiumTopicEtrog, diff --git a/zk/contracts/l1_contracts.go b/zk/contracts/l1_contracts.go index 7c4bb7f3d80..7ba052d7872 100644 --- a/zk/contracts/l1_contracts.go +++ b/zk/contracts/l1_contracts.go @@ -16,4 +16,5 @@ var ( AddNewRollupTypeTopic = common.HexToHash("0xa2970448b3bd66ba7e524e7b2a5b9cf94fa29e32488fb942afdfe70dd4b77b52") CreateNewRollupTopic = common.HexToHash("0x194c983456df6701c6a50830b90fe80e72b823411d0d524970c9590dc277a641") UpdateRollupTopic = common.HexToHash("0xf585e04c05d396901170247783d3e5f0ee9c1df23072985b50af089f5e48b19d") + RollbackBatchesTopic = common.HexToHash("0x1125aaf62d132d8e2d02005114f8fc360ff204c3105e4f1a700a1340dc55d5b1") ) diff --git a/zk/hermez_db/db.go b/zk/hermez_db/db.go index 24de3570578..b3e6319a9d4 100644 --- a/zk/hermez_db/db.go +++ b/zk/hermez_db/db.go @@ -528,6 +528,27 @@ func (db *HermezDb) WriteSequence(l1BlockNo, batchNo uint64, l1TxHash, stateRoot return db.tx.Put(L1SEQUENCES, ConcatKey(l1BlockNo, batchNo), val) } +// RollbackSequences deletes the sequences up to the given batch number +func (db *HermezDb) RollbackSequences(batchNo uint64) error { + for { + latestSequence, err := db.GetLatestSequence() + if err != nil { + return err + } + + if latestSequence == nil || latestSequence.BatchNo <= batchNo { + break + } + + err = db.tx.Delete(L1SEQUENCES, ConcatKey(latestSequence.L1BlockNo, latestSequence.BatchNo)) + if err != nil { + return err + } + } + + return nil +} + func (db *HermezDb) TruncateSequences(l2BlockNo uint64) error { batchNo, err := db.GetBatchNoByL2Block(l2BlockNo) if err != nil { diff --git a/zk/stages/stage_l1syncer.go b/zk/stages/stage_l1syncer.go index 513351e997a..8362247b8a4 100644 --- a/zk/stages/stage_l1syncer.go +++ b/zk/stages/stage_l1syncer.go @@ -149,6 +149,13 @@ Loop: highestWrittenL1BlockNo = info.L1BlockNo } newSequencesCount++ + case logRollbackBatches: + if err := hermezDb.RollbackSequences(info.BatchNo); err != nil { + return fmt.Errorf("failed to write rollback sequence, %w", err) + } + if info.L1BlockNo > highestWrittenL1BlockNo { + highestWrittenL1BlockNo = info.L1BlockNo + } case logVerify: if info.BatchNo > highestVerification.BatchNo { highestVerification = info @@ -222,6 +229,7 @@ var ( logSequence BatchLogType = 1 logVerify BatchLogType = 2 logL1InfoTreeUpdate BatchLogType = 4 + logRollbackBatches BatchLogType = 5 logIncompatible BatchLogType = 100 ) @@ -265,6 +273,9 @@ func parseLogType(l1RollupId uint64, log *ethTypes.Log) (l1BatchInfo types.L1Bat } case contracts.UpdateL1InfoTreeTopic: batchLogType = logL1InfoTreeUpdate + case contracts.RollbackBatchesTopic: + batchLogType = logRollbackBatches + batchNum = new(big.Int).SetBytes(log.Topics[1].Bytes()).Uint64() default: batchLogType = logUnknown batchNum = 0