diff --git a/smt/pkg/db/mdbx.go b/smt/pkg/db/mdbx.go index 18351d0fe6c..adca963eaac 100644 --- a/smt/pkg/db/mdbx.go +++ b/smt/pkg/db/mdbx.go @@ -2,6 +2,7 @@ package db import ( "context" + "encoding/hex" "math/big" "fmt" @@ -304,6 +305,19 @@ func (m *EriRoDb) GetCode(codeHash []byte) ([]byte, error) { return data, nil } +func (m *EriDb) AddCode(code []byte) error { + codeHash := utils.HashContractBytecode(hex.EncodeToString(code)) + + codeHashBytes, err := hex.DecodeString(strings.TrimPrefix(codeHash, "0x")) + if err != nil { + return err + } + + codeHashBytes = utils.ResizeHashTo32BytesByPrefixingWithZeroes(codeHashBytes) + + return m.tx.Put(kv.Code, codeHashBytes, code) +} + func (m *EriRoDb) PrintDb() { err := m.kvTxRo.ForEach(TableSmt, []byte{}, func(k, v []byte) error { println(string(k), string(v)) diff --git a/smt/pkg/smt/entity_storage.go b/smt/pkg/smt/entity_storage.go index e33a6d06357..566248f2c5a 100644 --- a/smt/pkg/smt/entity_storage.go +++ b/smt/pkg/smt/entity_storage.go @@ -14,30 +14,55 @@ import ( "github.com/ledgerwatch/erigon/smt/pkg/utils" ) +// SetAccountState sets the balance and nonce of an account func (s *SMT) SetAccountState(ethAddr string, balance, nonce *big.Int) (*big.Int, error) { + _, err := s.SetAccountBalance(ethAddr, balance) + if err != nil { + return nil, err + } + + auxOut, err := s.SetAccountNonce(ethAddr, nonce) + if err != nil { + return nil, err + } + + return auxOut, nil +} + +// SetAccountBalance sets the balance of an account +func (s *SMT) SetAccountBalance(ethAddr string, balance *big.Int) (*big.Int, error) { keyBalance := utils.KeyEthAddrBalance(ethAddr) - keyNonce := utils.KeyEthAddrNonce(ethAddr) - if _, err := s.InsertKA(keyBalance, balance); err != nil { + auxRes, err := s.InsertKA(keyBalance, balance) + if err != nil { return nil, err } ks := utils.EncodeKeySource(utils.KEY_BALANCE, utils.ConvertHexToAddress(ethAddr), common.Hash{}) - if err := s.Db.InsertKeySource(keyBalance, ks); err != nil { + err = s.Db.InsertKeySource(keyBalance, ks) + if err != nil { return nil, err } + return auxRes.NewRootScalar.ToBigInt(), err +} + +// SetAccountNonce sets the nonce of an account +func (s *SMT) SetAccountNonce(ethAddr string, nonce *big.Int) (*big.Int, error) { + keyNonce := utils.KeyEthAddrNonce(ethAddr) + auxRes, err := s.InsertKA(keyNonce, nonce) if err != nil { return nil, err } - ks = utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{}) - if err := s.Db.InsertKeySource(keyNonce, ks); err != nil { + ks := utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{}) + err = s.Db.InsertKeySource(keyNonce, ks) + if err != nil { return nil, err } - return auxRes.NewRootScalar.ToBigInt(), nil + return auxRes.NewRootScalar.ToBigInt(), err } func (s *SMT) SetAccountStorage(addr libcommon.Address, acc *accounts.Account) error { @@ -304,7 +329,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l for k, v := range storage { keyStoragePosition := utils.KeyContractStorage(ethAddrBigIngArray, k) - valueBigInt := convertStrintToBigInt(v) + valueBigInt := convertStringToBigInt(v) keysBatchStorage = append(keysBatchStorage, &keyStoragePosition) if valuesBatchStorage, isDelete, err = appendToValuesBatchStorageBigInt(valuesBatchStorage, valueBigInt); err != nil { return nil, nil, err @@ -341,7 +366,7 @@ func (s *SMT) DeleteKeySource(nodeKey *utils.NodeKey) error { } func calcHashVal(v string) (*utils.NodeValue8, [4]uint64, error) { - val := convertStrintToBigInt(v) + val := convertStringToBigInt(v) x := utils.ScalarToArrayBig(val) value, err := utils.NodeValue8FromBigIntArray(x) @@ -354,10 +379,10 @@ func calcHashVal(v string) (*utils.NodeValue8, [4]uint64, error) { return value, h, nil } -func convertStrintToBigInt(v string) *big.Int { +func convertStringToBigInt(v string) *big.Int { base := 10 if strings.HasPrefix(v, "0x") { - v = v[2:] + v = strings.TrimPrefix(v, "0x") base = 16 } diff --git a/smt/pkg/smt/smt.go b/smt/pkg/smt/smt.go index 1c2a8a386e4..50d0221916d 100644 --- a/smt/pkg/smt/smt.go +++ b/smt/pkg/smt/smt.go @@ -23,6 +23,7 @@ type DB interface { InsertKeySource(key utils.NodeKey, value []byte) error DeleteKeySource(key utils.NodeKey) error InsertHashKey(key utils.NodeKey, value utils.NodeKey) error + AddCode(code []byte) error DeleteHashKey(key utils.NodeKey) error Delete(string) error DeleteByNodeKey(key utils.NodeKey) error @@ -297,7 +298,9 @@ func (s *SMT) insert(k utils.NodeKey, v utils.NodeValue8, newValH [4]uint64, old if err != nil { return nil, err } - s.Db.InsertHashKey(newLeafHash, k) + if err := s.Db.InsertHashKey(newLeafHash, k); err != nil { + return nil, err + } if level >= 0 { for j := 0; j < 4; j++ { siblings[level][keys[level]*4+j] = new(big.Int).SetUint64(newLeafHash[j]) @@ -649,7 +652,7 @@ func (s *SMT) updateDepth(newDepth int) { newDepthAsByte := byte(newDepth & 0xFF) if oldDepth < newDepthAsByte { - s.Db.SetDepth(newDepthAsByte) + _ = s.Db.SetDepth(newDepthAsByte) } } @@ -728,3 +731,90 @@ func (s *RoSMT) traverseAndMark(ctx context.Context, node *big.Int, visited Visi return true, nil }) } + +// InsertHashNode inserts a hash node into the SMT. The SMT should not contain any other leaf nodes with the same path prefix. Otherwise, the new root hash will be incorrect. +// TODO: Support insertion of hash nodes even if there are leaf nodes with the same path prefix in SMT. +func (s *SMT) InsertHashNode(path []int, hash *big.Int) (*big.Int, error) { + s.clearUpMutex.Lock() + defer s.clearUpMutex.Unlock() + + or, err := s.getLastRoot() + if err != nil { + return nil, err + } + + h := utils.ScalarToArray(hash) + + var nodeHash [4]uint64 + copy(nodeHash[:], h[:4]) + + lastRoot, err := s.insertHashNode(path, nodeHash, or) + if err != nil { + return nil, err + } + + if err = s.setLastRoot(lastRoot); err != nil { + return nil, err + } + + return lastRoot.ToBigInt(), nil +} + +func (s *SMT) insertHashNode(path []int, hash [4]uint64, root utils.NodeKey) (utils.NodeKey, error) { + if len(path) == 0 { + newValHBig := utils.ArrayToScalar(hash[:]) + v := utils.ScalarToNodeValue8(newValHBig) + + err := s.hashSave(v.ToUintArray(), utils.LeafCapacity, hash) + if err != nil { + return utils.NodeKey{}, err + } + + return hash, nil + } + + rootVal := utils.NodeValue12{} + + if !root.IsZero() { + v, err := s.Db.Get(root) + if err != nil { + return utils.NodeKey{}, err + } + + rootVal = v + } + + childIndex := path[0] + + childOldRoot := rootVal[childIndex*4 : childIndex*4+4] + + childNewRoot, err := s.insertHashNode(path[1:], hash, utils.NodeKeyFromBigIntArray(childOldRoot)) + + if err != nil { + return utils.NodeKey{}, err + } + + var newIn [8]uint64 + + emptyRootVal := utils.NodeValue12{} + + if childIndex == 0 { + var sibling [4]uint64 + if rootVal == emptyRootVal { + sibling = [4]uint64{0, 0, 0, 0} + } else { + sibling = *rootVal.Get4to8() + } + newIn = utils.ConcatArrays4(childNewRoot, sibling) + } else { + var sibling [4]uint64 + if rootVal == emptyRootVal { + sibling = [4]uint64{0, 0, 0, 0} + } else { + sibling = *rootVal.Get0to4() + } + newIn = utils.ConcatArrays4(sibling, childNewRoot) + } + + return s.hashcalcAndSave(newIn, utils.BranchCapacity) +} diff --git a/smt/pkg/smt/smt_state_reader.go b/smt/pkg/smt/smt_state_reader.go new file mode 100644 index 00000000000..3c68a431b0f --- /dev/null +++ b/smt/pkg/smt/smt_state_reader.go @@ -0,0 +1,179 @@ +package smt + +import ( + "bytes" + "context" + "math/big" + + "github.com/holiman/uint256" + libcommon "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon/core/types/accounts" + "github.com/ledgerwatch/erigon/smt/pkg/utils" + "github.com/ledgerwatch/erigon/zkevm/log" +) + +// ReadAccountData reads account data from the SMT +func (s *SMT) ReadAccountData(address libcommon.Address) (*accounts.Account, error) { + account := accounts.Account{} + + balance, err := s.GetAccountBalance(address) + if err != nil { + return nil, err + } + account.Balance = *balance + + nonce, err := s.GetAccountNonce(address) + if err != nil { + return nil, err + } + account.Nonce = nonce.Uint64() + + codeHash, err := s.GetAccountCodeHash(address) + if err != nil { + return nil, err + } + account.CodeHash = codeHash + + account.Root = libcommon.Hash{} + + return &account, nil +} + +// ReadAccountStorage reads account storage from the SMT (not implemented for SMT) +func (s *SMT) ReadAccountStorage(address libcommon.Address, incarnation uint64, key *libcommon.Hash) ([]byte, error) { + value, err := s.getValue(0, address, key) + if err != nil { + return []byte{}, err + } + + return value, nil +} + +// ReadAccountCode reads account code from the SMT +func (s *SMT) ReadAccountCode(address libcommon.Address, incarnation uint64, codeHash libcommon.Hash) ([]byte, error) { + code, err := s.Db.GetCode(codeHash.Bytes()) + if err != nil { + return []byte{}, err + } + + return code, nil +} + +// ReadAccountCodeSize reads account code size from the SMT +func (s *SMT) ReadAccountCodeSize(address libcommon.Address, incarnation uint64, codeHash libcommon.Hash) (int, error) { + valueInBytes, err := s.getValue(utils.SC_LENGTH, address, nil) + if err != nil { + return 0, err + } + + sizeBig := big.NewInt(0).SetBytes(valueInBytes) + + return int(sizeBig.Int64()), nil +} + +// ReadAccountIncarnation reads account incarnation from the SMT (not implemented for SMT) +func (s *SMT) ReadAccountIncarnation(address libcommon.Address) (uint64, error) { + return 0, nil +} + +// GetAccountBalance returns the balance of an account from the SMT +func (s *SMT) GetAccountBalance(address libcommon.Address) (*uint256.Int, error) { + balance := uint256.NewInt(0) + + valueInBytes, err := s.getValue(utils.KEY_BALANCE, address, nil) + if err != nil { + log.Error("error getting balance", "error", err) + return nil, err + } + balance.SetBytes(valueInBytes) + + return balance, nil +} + +// GetAccountNonce returns the nonce of an account from the SMT +func (s *SMT) GetAccountNonce(address libcommon.Address) (*uint256.Int, error) { + nonce := uint256.NewInt(0) + + valueInBytes, err := s.getValue(utils.KEY_NONCE, address, nil) + if err != nil { + log.Error("error getting nonce", "error", err) + return nil, err + } + nonce.SetBytes(valueInBytes) + + return nonce, nil +} + +// GetAccountCodeHash returns the code hash of an account from the SMT +func (s *SMT) GetAccountCodeHash(address libcommon.Address) (libcommon.Hash, error) { + codeHash := libcommon.Hash{} + + valueInBytes, err := s.getValue(utils.SC_CODE, address, nil) + if err != nil { + log.Error("error getting codehash", "error", err) + return libcommon.Hash{}, err + } + codeHash.SetBytes(valueInBytes) + + return codeHash, nil +} + +// getValue returns the value of a key from SMT by traversing the SMT +func (s *SMT) getValue(key int, address libcommon.Address, storageKey *libcommon.Hash) ([]byte, error) { + var kn utils.NodeKey + + if storageKey == nil { + kn = utils.Key(address.String(), key) + } else { + a := utils.ConvertHexToBigInt(address.String()) + add := utils.ScalarToArrayBig(a) + + kn = utils.KeyContractStorage(add, storageKey.String()) + } + + return s.getValueInBytes(kn) +} + +// getValueInBytes returns the value of a key from SMT in bytes by traversing the SMT +func (s *SMT) getValueInBytes(nodeKey utils.NodeKey) ([]byte, error) { + value := []byte{} + + keyPath := nodeKey.GetPath() + + keyPathBytes := make([]byte, 0) + for _, k := range keyPath { + keyPathBytes = append(keyPathBytes, byte(k)) + } + + action := func(prefix []byte, k utils.NodeKey, v utils.NodeValue12) (bool, error) { + if !bytes.HasPrefix(keyPathBytes, prefix) { + return false, nil + } + + if v.IsFinalNode() { + valHash := v.Get4to8() + v, err := s.Db.Get(*valHash) + if err != nil { + return false, err + } + vInBytes := utils.ArrayBigToScalar(utils.BigIntArrayFromNodeValue8(v.GetNodeValue8())).Bytes() + + value = vInBytes + return false, nil + } + + return true, nil + } + + root, err := s.Db.GetLastRoot() + if err != nil { + return nil, err + } + + err = s.Traverse(context.Background(), root, action) + if err != nil { + return nil, err + } + + return value, nil +} diff --git a/smt/pkg/smt/witness.go b/smt/pkg/smt/witness.go index ce64d08107e..5e39bc5b973 100644 --- a/smt/pkg/smt/witness.go +++ b/smt/pkg/smt/witness.go @@ -2,12 +2,16 @@ package smt import ( "context" + "fmt" + "math/big" libcommon "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon/smt/pkg/utils" "github.com/ledgerwatch/erigon/turbo/trie" + "github.com/status-im/keycard-go/hexutils" ) +// BuildWitness creates a witness from the SMT func BuildWitness(s *SMT, rd trie.RetainDecider, ctx context.Context) (*trie.Witness, error) { operands := make([]trie.WitnessOperator, 0) @@ -33,7 +37,7 @@ func BuildWitness(s *SMT, rd trie.RetainDecider, ctx context.Context) (*trie.Wit This algorithm adds a little bit more nodes to the witness but it ensures that all requiring nodes are included. */ - retain := true + retain := false prefixLen := len(prefix) if prefixLen > 0 { @@ -112,3 +116,166 @@ func BuildWitness(s *SMT, rd trie.RetainDecider, ctx context.Context) (*trie.Wit return trie.NewWitness(operands), err } + +// BuildSMTfromWitness builds SMT from witness +func BuildSMTfromWitness(w *trie.Witness) (*SMT, error) { + // using memdb + s := NewSMT(nil, false) + + balanceMap := make(map[string]*big.Int) + nonceMap := make(map[string]*big.Int) + contractMap := make(map[string]string) + storageMap := make(map[string]map[string]string) + + path := make([]int, 0) + + firstNode := true + NodeChildCountMap := make(map[string]uint32) + NodesBranchValueMap := make(map[string]uint32) + + type nodeHash struct { + path []int + hash libcommon.Hash + } + + nodeHashes := make([]nodeHash, 0) + + for i, operator := range w.Operators { + switch op := operator.(type) { + case *trie.OperatorSMTLeafValue: + valScaler := big.NewInt(0).SetBytes(op.Value) + addr := libcommon.BytesToAddress(op.Address) + + switch op.NodeType { + case utils.KEY_BALANCE: + balanceMap[addr.String()] = valScaler + + case utils.KEY_NONCE: + nonceMap[addr.String()] = valScaler + + case utils.SC_STORAGE: + if _, ok := storageMap[addr.String()]; !ok { + storageMap[addr.String()] = make(map[string]string) + } + + stKey := hexutils.BytesToHex(op.StorageKey) + if len(stKey) > 0 { + stKey = fmt.Sprintf("0x%s", stKey) + } + + storageMap[addr.String()][stKey] = valScaler.String() + } + + path = path[:len(path)-1] + NodeChildCountMap[intArrayToString(path)] += 1 + + for len(path) != 0 && NodeChildCountMap[intArrayToString(path)] == NodesBranchValueMap[intArrayToString(path)] { + path = path[:len(path)-1] + } + if NodeChildCountMap[intArrayToString(path)] < NodesBranchValueMap[intArrayToString(path)] { + path = append(path, 1) + } + + case *trie.OperatorCode: + addr := libcommon.BytesToAddress(w.Operators[i+1].(*trie.OperatorSMTLeafValue).Address) + + code := hexutils.BytesToHex(op.Code) + if len(code) > 0 { + if err := s.Db.AddCode(hexutils.HexToBytes(code)); err != nil { + return nil, err + } + code = fmt.Sprintf("0x%s", code) + } + + contractMap[addr.String()] = code + + case *trie.OperatorBranch: + if firstNode { + firstNode = false + } else { + NodeChildCountMap[intArrayToString(path[:len(path)-1])] += 1 + } + + switch op.Mask { + case 1: + NodesBranchValueMap[intArrayToString(path)] = 1 + path = append(path, 0) + case 2: + NodesBranchValueMap[intArrayToString(path)] = 1 + path = append(path, 1) + case 3: + NodesBranchValueMap[intArrayToString(path)] = 2 + path = append(path, 0) + } + + case *trie.OperatorHash: + pathCopy := make([]int, len(path)) + copy(pathCopy, path) + nodeHashes = append(nodeHashes, nodeHash{path: pathCopy, hash: op.Hash}) + + path = path[:len(path)-1] + NodeChildCountMap[intArrayToString(path)] += 1 + + for len(path) != 0 && NodeChildCountMap[intArrayToString(path)] == NodesBranchValueMap[intArrayToString(path)] { + path = path[:len(path)-1] + } + if NodeChildCountMap[intArrayToString(path)] < NodesBranchValueMap[intArrayToString(path)] { + path = append(path, 1) + } + + default: + // Unsupported operator type + return nil, fmt.Errorf("unsupported operator type: %T", op) + } + } + + for _, nodeHash := range nodeHashes { + _, err := s.InsertHashNode(nodeHash.path, nodeHash.hash.Big()) + if err != nil { + return nil, err + } + + _, err = s.Db.GetLastRoot() + if err != nil { + return nil, err + } + } + + for addr, balance := range balanceMap { + _, err := s.SetAccountBalance(addr, balance) + if err != nil { + return nil, err + } + } + + for addr, nonce := range nonceMap { + _, err := s.SetAccountNonce(addr, nonce) + if err != nil { + return nil, err + } + } + + for addr, code := range contractMap { + err := s.SetContractBytecode(addr, code) + if err != nil { + return nil, err + } + } + + for addr, storage := range storageMap { + _, err := s.SetContractStorage(addr, storage, nil) + if err != nil { + fmt.Println("error : unable to set contract storage", err) + } + } + + return s, nil +} + +func intArrayToString(a []int) string { + s := "" + for _, v := range a { + s += fmt.Sprintf("%d", v) + } + return s +} diff --git a/smt/pkg/smt/witness_test.go b/smt/pkg/smt/witness_test.go index ca0dcdf2f91..7078fec5ed5 100644 --- a/smt/pkg/smt/witness_test.go +++ b/smt/pkg/smt/witness_test.go @@ -16,6 +16,7 @@ import ( "github.com/ledgerwatch/erigon/smt/pkg/smt" "github.com/ledgerwatch/erigon/smt/pkg/utils" "github.com/ledgerwatch/erigon/turbo/trie" + "github.com/stretchr/testify/require" ) func prepareSMT(t *testing.T) (*smt.SMT, *trie.RetainList) { @@ -158,3 +159,95 @@ func TestSMTWitnessRetainListEmptyVal(t *testing.T) { t.Errorf("witness contains unexpected operator") } } + +// TestWitnessToSMT tests that the SMT built from a witness matches the original SMT +func TestWitnessToSMT(t *testing.T) { + smtTrie, rl := prepareSMT(t) + + witness, err := smt.BuildWitness(smtTrie, rl, context.Background()) + if err != nil { + t.Errorf("error building witness: %v", err) + } + + newSMT, err := smt.BuildSMTfromWitness(witness) + if err != nil { + t.Errorf("error building SMT from witness: %v", err) + } + + root, err := newSMT.Db.GetLastRoot() + if err != nil { + t.Errorf("error getting last root: %v", err) + } + + // newSMT.Traverse(context.Background(), root, func(prefix []byte, k utils.NodeKey, v utils.NodeValue12) (bool, error) { + // fmt.Printf("[After] path: %v, hash: %x\n", prefix, libcommon.BigToHash(k.ToBigInt())) + // return true, nil + // }) + + expectedRoot, err := smtTrie.Db.GetLastRoot() + require.NoError(t, err, "error getting last root") + + // assert that the roots are the same + if expectedRoot.Cmp(root) != 0 { + t.Errorf(fmt.Sprintf("SMT root mismatch, expected %x, got %x", expectedRoot.Bytes(), root.Bytes())) + } +} + +// TestWitnessToSMTStateReader tests that the SMT built from a witness matches the state +func TestWitnessToSMTStateReader(t *testing.T) { + smtTrie, rl := prepareSMT(t) + + sKey := libcommon.HexToHash("0x5") + + expectedRoot, err := smtTrie.Db.GetLastRoot() + if err != nil { + t.Errorf("error getting last root: %v", err) + } + + witness, err := smt.BuildWitness(smtTrie, rl, context.Background()) + if err != nil { + t.Errorf("error building witness: %v", err) + } + + newSMT, err := smt.BuildSMTfromWitness(witness) + if err != nil { + t.Errorf("error building SMT from witness: %v", err) + } + root, err := newSMT.Db.GetLastRoot() + if err != nil { + t.Errorf("error building SMT from witness: %v", err) + } + + if expectedRoot.Cmp(root) != 0 { + t.Errorf(fmt.Sprintf("SMT root mismatch, expected %x, got %x", expectedRoot.Bytes(), root.Bytes())) + } + + contract := libcommon.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624") + + expectedAcc, _ := smtTrie.ReadAccountData(contract) + newAcc, _ := newSMT.ReadAccountData(contract) + + expectedAccCode, _ := smtTrie.ReadAccountCode(contract, 0, expectedAcc.CodeHash) + newAccCode, _ := newSMT.ReadAccountCode(contract, 0, newAcc.CodeHash) + expectedAccCodeSize, _ := smtTrie.ReadAccountCodeSize(contract, 0, expectedAcc.CodeHash) + newAccCodeSize, _ := newSMT.ReadAccountCodeSize(contract, 0, newAcc.CodeHash) + expectedStorageValue, _ := smtTrie.ReadAccountStorage(contract, 0, &sKey) + newStorageValue, _ := newSMT.ReadAccountStorage(contract, 0, &sKey) + // assert that the account data is the same + require.Equal(t, expectedAcc, newAcc) + + require.Equal(t, expectedAccCode, newAccCode) + // TODO: @Stefan-Ethernal Check and remove + // // assert that the account code is the same + // if !bytes.Equal(expectedAccCode, newAccCode) { + // t.Error("Account Code Mismatch") + // } + + // assert that the account code size is the same + require.Equal(t, expectedAccCodeSize, newAccCodeSize) + + // assert that the storage value is the same + if !bytes.Equal(expectedStorageValue, newStorageValue) { + t.Error("Storage Value Mismatch") + } +}