diff --git a/geth-utils/gethutil/mpt/oracle/prefetch.go b/geth-utils/gethutil/mpt/oracle/prefetch.go index 34907912b1..4b5a6d1d99 100644 --- a/geth-utils/gethutil/mpt/oracle/prefetch.go +++ b/geth-utils/gethutil/mpt/oracle/prefetch.go @@ -81,7 +81,8 @@ var RemoteUrl = "https://mainnet.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161" var LocalUrl = "http://localhost:8545" // For generating special tests for MPT circuit: -var PreventHashingInSecureTrie = false +var PreventHashingInSecureTrie = false // storage +var AccountPreventHashingInSecureTrie = false func toFilename(key string) string { return fmt.Sprintf("/tmp/eth/json_%s", key) diff --git a/geth-utils/gethutil/mpt/state/database.go b/geth-utils/gethutil/mpt/state/database.go index 52cdf1b596..37acd97c7b 100644 --- a/geth-utils/gethutil/mpt/state/database.go +++ b/geth-utils/gethutil/mpt/state/database.go @@ -52,7 +52,7 @@ func (db *Database) CopyTrie(t Trie) Trie { // OpenTrie opens the main account trie at a specific root hash. func (db *Database) OpenTrie(root common.Hash) (Trie, error) { - tr, err := trie.NewSecure(root, db.db) + tr, err := trie.NewSecure(root, db.db, false) if err != nil { return nil, err } @@ -62,7 +62,7 @@ func (db *Database) OpenTrie(root common.Hash) (Trie, error) { // OpenStorageTrie opens the storage trie of an account. func (db *Database) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { //return SimpleTrie{db.BlockNumber, root, true, addrHash}, nil - tr, err := trie.NewSecure(root, db.db) + tr, err := trie.NewSecure(root, db.db, true) if err != nil { return nil, err } diff --git a/geth-utils/gethutil/mpt/state/statedb.go b/geth-utils/gethutil/mpt/state/statedb.go index 07ad5311ec..12414c504d 100644 --- a/geth-utils/gethutil/mpt/state/statedb.go +++ b/geth-utils/gethutil/mpt/state/statedb.go @@ -298,8 +298,8 @@ func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { // GetProof returns the Merkle proof for a given account. func (s *StateDB) GetProof(addr common.Address) ([][]byte, []byte, [][]byte, bool, bool, error) { - var newAddr common.Hash - if oracle.PreventHashingInSecureTrie { + newAddr := crypto.Keccak256Hash(addr.Bytes()) + if oracle.AccountPreventHashingInSecureTrie { bytes := append(addr.Bytes(), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}...) newAddr = common.BytesToHash(bytes) } @@ -317,7 +317,7 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, []byte, [][]by func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, []byte, [][]byte, bool, bool, error) { var proof proofList newAddr := a - if oracle.PreventHashingInSecureTrie { + if oracle.AccountPreventHashingInSecureTrie { bytes := append(a.Bytes(), []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}...) newAddr = common.BytesToAddress(bytes) } @@ -545,7 +545,7 @@ func (s *StateDB) updateStateObject(obj *stateObject) { panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err)) } - if !oracle.PreventHashingInSecureTrie { + if !oracle.AccountPreventHashingInSecureTrie { if err = s.trie.TryUpdateAlwaysHash(addr[:], data); err != nil { s.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err)) } diff --git a/geth-utils/gethutil/mpt/trie/secure_trie.go b/geth-utils/gethutil/mpt/trie/secure_trie.go index 2831de5bd5..ae34f1a8db 100644 --- a/geth-utils/gethutil/mpt/trie/secure_trie.go +++ b/geth-utils/gethutil/mpt/trie/secure_trie.go @@ -40,6 +40,7 @@ type SecureTrie struct { hashKeyBuf [common.HashLength]byte secKeyCache map[string][]byte secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch + isStorageTrie bool } // NewSecure creates a trie with an existing root node from a backing database @@ -53,7 +54,7 @@ type SecureTrie struct { // Loaded nodes are kept around until their 'cache generation' expires. // A new cache generation is created by each call to Commit. // cachelimit sets the number of past cache generations to keep. -func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) { +func NewSecure(root common.Hash, db *Database, isStorageTrie bool) (*SecureTrie, error) { if db == nil { panic("trie.NewSecure called without a database") } @@ -61,7 +62,7 @@ func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) { if err != nil { return nil, err } - return &SecureTrie{trie: *trie}, nil + return &SecureTrie{trie: *trie, isStorageTrie: isStorageTrie}, nil } // Get returns the value for key stored in the trie. @@ -202,7 +203,8 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { // The caller must not hold onto the return value because it will become // invalid on the next call to hashKey or secKey. func (t *SecureTrie) hashKey(key []byte) []byte { - if !oracle.PreventHashingInSecureTrie { + preventHashing := (oracle.PreventHashingInSecureTrie && t.isStorageTrie) || (oracle.AccountPreventHashingInSecureTrie && !t.isStorageTrie) + if !preventHashing { h := NewHasher(false) h.sha.Reset() h.sha.Write(key) diff --git a/geth-utils/gethutil/mpt/witness/nodes.go b/geth-utils/gethutil/mpt/witness/nodes.go index d1d40ee83b..74a0f9d395 100644 --- a/geth-utils/gethutil/mpt/witness/nodes.go +++ b/geth-utils/gethutil/mpt/witness/nodes.go @@ -180,7 +180,7 @@ type Node struct { func GetStartNode(proofType string, sRoot, cRoot common.Hash, specialTest byte) Node { s := StartNode{ - DisablePreimageCheck: oracle.PreventHashingInSecureTrie || specialTest == 5, + DisablePreimageCheck: oracle.PreventHashingInSecureTrie || oracle.AccountPreventHashingInSecureTrie || specialTest == 5, ProofType: proofType, } var values [][]byte diff --git a/geth-utils/gethutil/mpt/witness/prepare_witness.go b/geth-utils/gethutil/mpt/witness/prepare_witness.go index b36cada0da..bda8d2e2bf 100644 --- a/geth-utils/gethutil/mpt/witness/prepare_witness.go +++ b/geth-utils/gethutil/mpt/witness/prepare_witness.go @@ -70,7 +70,7 @@ func obtainAccountProofAndConvertToWitness(tMod TrieModification, statedb *state addr := tMod.Address addrh := crypto.Keccak256(addr.Bytes()) - if oracle.PreventHashingInSecureTrie { + if oracle.AccountPreventHashingInSecureTrie { addrh = addr.Bytes() addrh = append(addrh, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}...) addr = common.BytesToAddress(addrh) @@ -180,7 +180,7 @@ func obtainTwoProofsAndConvertToWitness(trieModifications []TrieModification, st addr := tMod.Address addrh := crypto.Keccak256(addr.Bytes()) - if oracle.PreventHashingInSecureTrie { + if oracle.AccountPreventHashingInSecureTrie { addrh = addr.Bytes() addrh = append(addrh, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}...) addr = common.BytesToAddress(addrh)