From df6709f065b59a299e893b30db714b82b8a0e41e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Negovanovi=C4=87?= <93934272+Stefan-Ethernal@users.noreply.github.com> Date: Mon, 4 Nov 2024 20:36:32 +0100 Subject: [PATCH] feat: Decode witness to SMT (#1363) * feat: Decode witness to SMT * chore: warning fixes and simplifications * test: use require in witness unit tests * feat: simplifications in SMT state reader * fix: address comment * test: use requires * Allocate array in getValueInBytes --- smt/pkg/db/mdbx.go | 14 +++ smt/pkg/smt/entity_storage.go | 69 +++++++----- smt/pkg/smt/smt.go | 94 +++++++++++++++- smt/pkg/smt/smt_state_reader.go | 193 ++++++++++++++++++++++++++++++++ smt/pkg/smt/witness.go | 169 +++++++++++++++++++++++++++- smt/pkg/smt/witness_test.go | 160 +++++++++++++++++++------- 6 files changed, 628 insertions(+), 71 deletions(-) create mode 100644 smt/pkg/smt/smt_state_reader.go diff --git a/smt/pkg/db/mdbx.go b/smt/pkg/db/mdbx.go index 18351d0fe6c..adca963eaac 100644 --- a/smt/pkg/db/mdbx.go +++ b/smt/pkg/db/mdbx.go @@ -2,6 +2,7 @@ package db import ( "context" + "encoding/hex" "math/big" "fmt" @@ -304,6 +305,19 @@ func (m *EriRoDb) GetCode(codeHash []byte) ([]byte, error) { return data, nil } +func (m *EriDb) AddCode(code []byte) error { + codeHash := utils.HashContractBytecode(hex.EncodeToString(code)) + + codeHashBytes, err := hex.DecodeString(strings.TrimPrefix(codeHash, "0x")) + if err != nil { + return err + } + + codeHashBytes = utils.ResizeHashTo32BytesByPrefixingWithZeroes(codeHashBytes) + + return m.tx.Put(kv.Code, codeHashBytes, code) +} + func (m *EriRoDb) PrintDb() { err := m.kvTxRo.ForEach(TableSmt, []byte{}, func(k, v []byte) error { println(string(k), string(v)) diff --git a/smt/pkg/smt/entity_storage.go b/smt/pkg/smt/entity_storage.go index e33a6d06357..261b27103cd 100644 --- a/smt/pkg/smt/entity_storage.go +++ b/smt/pkg/smt/entity_storage.go @@ -14,30 +14,55 @@ import ( "github.com/ledgerwatch/erigon/smt/pkg/utils" ) +// SetAccountState sets the balance and nonce of an account func (s *SMT) SetAccountState(ethAddr string, balance, nonce *big.Int) (*big.Int, error) { + _, err := s.SetAccountBalance(ethAddr, balance) + if err != nil { + return nil, err + } + + auxOut, err := s.SetAccountNonce(ethAddr, nonce) + if err != nil { + return nil, err + } + + return auxOut, nil +} + +// SetAccountBalance sets the balance of an account +func (s *SMT) SetAccountBalance(ethAddr string, balance *big.Int) (*big.Int, error) { keyBalance := utils.KeyEthAddrBalance(ethAddr) - keyNonce := utils.KeyEthAddrNonce(ethAddr) - if _, err := s.InsertKA(keyBalance, balance); err != nil { + response, err := s.InsertKA(keyBalance, balance) + if err != nil { return nil, err } ks := utils.EncodeKeySource(utils.KEY_BALANCE, utils.ConvertHexToAddress(ethAddr), common.Hash{}) - if err := s.Db.InsertKeySource(keyBalance, ks); err != nil { + err = s.Db.InsertKeySource(keyBalance, ks) + if err != nil { return nil, err } - auxRes, err := s.InsertKA(keyNonce, nonce) + return response.NewRootScalar.ToBigInt(), err +} + +// SetAccountNonce sets the nonce of an account +func (s *SMT) SetAccountNonce(ethAddr string, nonce *big.Int) (*big.Int, error) { + keyNonce := utils.KeyEthAddrNonce(ethAddr) + + response, err := s.InsertKA(keyNonce, nonce) if err != nil { return nil, err } - ks = utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{}) - if err := s.Db.InsertKeySource(keyNonce, ks); err != nil { + ks := utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{}) + err = s.Db.InsertKeySource(keyNonce, ks) + if err != nil { return nil, err } - return auxRes.NewRootScalar.ToBigInt(), nil + return response.NewRootScalar.ToBigInt(), nil } func (s *SMT) SetAccountStorage(addr libcommon.Address, acc *accounts.Account) error { @@ -80,13 +105,7 @@ func (s *SMT) SetContractBytecode(ethAddr string, bytecode string) error { ks = utils.EncodeKeySource(utils.SC_LENGTH, utils.ConvertHexToAddress(ethAddr), common.Hash{}) - err = s.Db.InsertKeySource(keyContractLength, ks) - - if err != nil { - return err - } - - return err + return s.Db.InsertKeySource(keyContractLength, ks) } func (s *SMT) SetContractStorage(ethAddr string, storage map[string]string, progressChan chan uint64) (*big.Int, error) { @@ -203,7 +222,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l for addr, acc := range accChanges { select { case <-ctx.Done(): - return nil, nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", logPrefix)) + return nil, nil, fmt.Errorf("[%s] Context done", logPrefix) default: } ethAddr := addr.String() @@ -250,7 +269,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l for addr, code := range codeChanges { select { case <-ctx.Done(): - return nil, nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", logPrefix)) + return nil, nil, fmt.Errorf("[%s] Context done", logPrefix) default: } @@ -295,7 +314,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l for addr, storage := range storageChanges { select { case <-ctx.Done(): - return nil, nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", logPrefix)) + return nil, nil, fmt.Errorf("[%s] Context done", logPrefix) default: } ethAddr := addr.String() @@ -304,7 +323,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l for k, v := range storage { keyStoragePosition := utils.KeyContractStorage(ethAddrBigIngArray, k) - valueBigInt := convertStrintToBigInt(v) + valueBigInt := convertStringToBigInt(v) keysBatchStorage = append(keysBatchStorage, &keyStoragePosition) if valuesBatchStorage, isDelete, err = appendToValuesBatchStorageBigInt(valuesBatchStorage, valueBigInt); err != nil { return nil, nil, err @@ -341,7 +360,7 @@ func (s *SMT) DeleteKeySource(nodeKey *utils.NodeKey) error { } func calcHashVal(v string) (*utils.NodeValue8, [4]uint64, error) { - val := convertStrintToBigInt(v) + val := convertStringToBigInt(v) x := utils.ScalarToArrayBig(val) value, err := utils.NodeValue8FromBigIntArray(x) @@ -354,10 +373,10 @@ func calcHashVal(v string) (*utils.NodeValue8, [4]uint64, error) { return value, h, nil } -func convertStrintToBigInt(v string) *big.Int { +func convertStringToBigInt(v string) *big.Int { base := 10 if strings.HasPrefix(v, "0x") { - v = v[2:] + v = strings.TrimPrefix(v, "0x") base = 16 } @@ -374,14 +393,8 @@ func appendToValuesBatchStorageBigInt(valuesBatchStorage []*utils.NodeValue8, va } func convertBytecodeToBigInt(bytecode string) (*big.Int, int, error) { - var parsedBytecode string bi := utils.HashContractBytecodeBigInt(bytecode) - - if strings.HasPrefix(bytecode, "0x") { - parsedBytecode = bytecode[2:] - } else { - parsedBytecode = bytecode - } + parsedBytecode := strings.TrimPrefix(bytecode, "0x") if len(parsedBytecode)%2 != 0 { parsedBytecode = "0" + parsedBytecode diff --git a/smt/pkg/smt/smt.go b/smt/pkg/smt/smt.go index 1c2a8a386e4..50d0221916d 100644 --- a/smt/pkg/smt/smt.go +++ b/smt/pkg/smt/smt.go @@ -23,6 +23,7 @@ type DB interface { InsertKeySource(key utils.NodeKey, value []byte) error DeleteKeySource(key utils.NodeKey) error InsertHashKey(key utils.NodeKey, value utils.NodeKey) error + AddCode(code []byte) error DeleteHashKey(key utils.NodeKey) error Delete(string) error DeleteByNodeKey(key utils.NodeKey) error @@ -297,7 +298,9 @@ func (s *SMT) insert(k utils.NodeKey, v utils.NodeValue8, newValH [4]uint64, old if err != nil { return nil, err } - s.Db.InsertHashKey(newLeafHash, k) + if err := s.Db.InsertHashKey(newLeafHash, k); err != nil { + return nil, err + } if level >= 0 { for j := 0; j < 4; j++ { siblings[level][keys[level]*4+j] = new(big.Int).SetUint64(newLeafHash[j]) @@ -649,7 +652,7 @@ func (s *SMT) updateDepth(newDepth int) { newDepthAsByte := byte(newDepth & 0xFF) if oldDepth < newDepthAsByte { - s.Db.SetDepth(newDepthAsByte) + _ = s.Db.SetDepth(newDepthAsByte) } } @@ -728,3 +731,90 @@ func (s *RoSMT) traverseAndMark(ctx context.Context, node *big.Int, visited Visi return true, nil }) } + +// InsertHashNode inserts a hash node into the SMT. The SMT should not contain any other leaf nodes with the same path prefix. Otherwise, the new root hash will be incorrect. +// TODO: Support insertion of hash nodes even if there are leaf nodes with the same path prefix in SMT. +func (s *SMT) InsertHashNode(path []int, hash *big.Int) (*big.Int, error) { + s.clearUpMutex.Lock() + defer s.clearUpMutex.Unlock() + + or, err := s.getLastRoot() + if err != nil { + return nil, err + } + + h := utils.ScalarToArray(hash) + + var nodeHash [4]uint64 + copy(nodeHash[:], h[:4]) + + lastRoot, err := s.insertHashNode(path, nodeHash, or) + if err != nil { + return nil, err + } + + if err = s.setLastRoot(lastRoot); err != nil { + return nil, err + } + + return lastRoot.ToBigInt(), nil +} + +func (s *SMT) insertHashNode(path []int, hash [4]uint64, root utils.NodeKey) (utils.NodeKey, error) { + if len(path) == 0 { + newValHBig := utils.ArrayToScalar(hash[:]) + v := utils.ScalarToNodeValue8(newValHBig) + + err := s.hashSave(v.ToUintArray(), utils.LeafCapacity, hash) + if err != nil { + return utils.NodeKey{}, err + } + + return hash, nil + } + + rootVal := utils.NodeValue12{} + + if !root.IsZero() { + v, err := s.Db.Get(root) + if err != nil { + return utils.NodeKey{}, err + } + + rootVal = v + } + + childIndex := path[0] + + childOldRoot := rootVal[childIndex*4 : childIndex*4+4] + + childNewRoot, err := s.insertHashNode(path[1:], hash, utils.NodeKeyFromBigIntArray(childOldRoot)) + + if err != nil { + return utils.NodeKey{}, err + } + + var newIn [8]uint64 + + emptyRootVal := utils.NodeValue12{} + + if childIndex == 0 { + var sibling [4]uint64 + if rootVal == emptyRootVal { + sibling = [4]uint64{0, 0, 0, 0} + } else { + sibling = *rootVal.Get4to8() + } + newIn = utils.ConcatArrays4(childNewRoot, sibling) + } else { + var sibling [4]uint64 + if rootVal == emptyRootVal { + sibling = [4]uint64{0, 0, 0, 0} + } else { + sibling = *rootVal.Get0to4() + } + newIn = utils.ConcatArrays4(sibling, childNewRoot) + } + + return s.hashcalcAndSave(newIn, utils.BranchCapacity) +} diff --git a/smt/pkg/smt/smt_state_reader.go b/smt/pkg/smt/smt_state_reader.go new file mode 100644 index 00000000000..4e1b4849497 --- /dev/null +++ b/smt/pkg/smt/smt_state_reader.go @@ -0,0 +1,193 @@ +package smt + +import ( + "bytes" + "context" + "errors" + "math/big" + + "github.com/holiman/uint256" + libcommon "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon/core/state" + "github.com/ledgerwatch/erigon/core/types/accounts" + "github.com/ledgerwatch/erigon/smt/pkg/utils" + "github.com/ledgerwatch/erigon/zkevm/log" +) + +var _ state.StateReader = (*SMT)(nil) + +// ReadAccountData reads account data from the SMT +func (s *SMT) ReadAccountData(address libcommon.Address) (*accounts.Account, error) { + balance, err := s.GetAccountBalance(address) + if err != nil { + return nil, err + } + + nonce, err := s.GetAccountNonce(address) + if err != nil { + return nil, err + } + + codeHash, err := s.GetAccountCodeHash(address) + if err != nil { + return nil, err + } + + account := &accounts.Account{ + Balance: *balance, + Nonce: nonce.Uint64(), + CodeHash: codeHash, + Root: libcommon.Hash{}, + } + + return account, nil +} + +// ReadAccountStorage reads account storage from the SMT (not implemented for SMT) +func (s *SMT) ReadAccountStorage(address libcommon.Address, incarnation uint64, key *libcommon.Hash) ([]byte, error) { + value, err := s.getValue(0, address, key) + if err != nil { + return []byte{}, err + } + + return value, nil +} + +// ReadAccountCode reads account code from the SMT +func (s *SMT) ReadAccountCode(address libcommon.Address, incarnation uint64, codeHash libcommon.Hash) ([]byte, error) { + code, err := s.Db.GetCode(codeHash.Bytes()) + if err != nil { + return []byte{}, err + } + + return code, nil +} + +// ReadAccountCodeSize reads account code size from the SMT +func (s *SMT) ReadAccountCodeSize(address libcommon.Address, _ uint64, _ libcommon.Hash) (int, error) { + valueInBytes, err := s.getValue(utils.SC_LENGTH, address, nil) + if err != nil { + return 0, err + } + + sizeBig := big.NewInt(0).SetBytes(valueInBytes) + + if !sizeBig.IsInt64() { + err = errors.New("code size value is too large to fit into an int") + return 0, err + } + + sizeInt64 := sizeBig.Int64() + if sizeInt64 > int64(^uint(0)>>1) { + err = errors.New("code size value overflows int") + log.Error("failed to get account code size", "error", err) + return 0, err + } + + return int(sizeInt64), nil +} + +// ReadAccountIncarnation reads account incarnation from the SMT (not implemented for SMT) +func (s *SMT) ReadAccountIncarnation(_ libcommon.Address) (uint64, error) { + return 0, errors.New("ReadAccountIncarnation not implemented for SMT") +} + +// GetAccountBalance returns the balance of an account from the SMT +func (s *SMT) GetAccountBalance(address libcommon.Address) (*uint256.Int, error) { + valueInBytes, err := s.getValue(utils.KEY_BALANCE, address, nil) + if err != nil { + log.Error("failed to get balance", "error", err) + return nil, err + } + + balance := uint256.NewInt(0).SetBytes(valueInBytes) + + return balance, nil +} + +// GetAccountNonce returns the nonce of an account from the SMT +func (s *SMT) GetAccountNonce(address libcommon.Address) (*uint256.Int, error) { + valueInBytes, err := s.getValue(utils.KEY_NONCE, address, nil) + if err != nil { + log.Error("failed to get nonce", "error", err) + return nil, err + } + + nonce := uint256.NewInt(0).SetBytes(valueInBytes) + + return nonce, nil +} + +// GetAccountCodeHash returns the code hash of an account from the SMT +func (s *SMT) GetAccountCodeHash(address libcommon.Address) (libcommon.Hash, error) { + valueInBytes, err := s.getValue(utils.SC_CODE, address, nil) + if err != nil { + log.Error("failed to get code hash", "error", err) + return libcommon.Hash{}, err + } + + codeHash := libcommon.Hash{} + codeHash.SetBytes(valueInBytes) + + return codeHash, nil +} + +// getValue returns the value of a key from SMT by traversing the SMT +func (s *SMT) getValue(key int, address libcommon.Address, storageKey *libcommon.Hash) ([]byte, error) { + var kn utils.NodeKey + + if storageKey == nil { + kn = utils.Key(address.String(), key) + } else { + a := utils.ConvertHexToBigInt(address.String()) + add := utils.ScalarToArrayBig(a) + + kn = utils.KeyContractStorage(add, storageKey.String()) + } + + return s.getValueInBytes(kn) +} + +// getValueInBytes returns the value of a key from SMT in bytes by traversing the SMT +func (s *SMT) getValueInBytes(nodeKey utils.NodeKey) ([]byte, error) { + value := []byte{} + + keyPath := nodeKey.GetPath() + + keyPathBytes := make([]byte, len(keyPath)) + for i, k := range keyPath { + keyPathBytes[i] = byte(k) + } + + action := func(prefix []byte, _ utils.NodeKey, v utils.NodeValue12) (bool, error) { + if !bytes.HasPrefix(keyPathBytes, prefix) { + return false, nil + } + + if v.IsFinalNode() { + valHash := v.Get4to8() + v, err := s.Db.Get(*valHash) + if err != nil { + return false, err + } + vInBytes := utils.ArrayBigToScalar(utils.BigIntArrayFromNodeValue8(v.GetNodeValue8())).Bytes() + + value = vInBytes + return false, nil + } + + return true, nil + } + + root, err := s.Db.GetLastRoot() + if err != nil { + return nil, err + } + + err = s.Traverse(context.Background(), root, action) + if err != nil { + return nil, err + } + + return value, nil +} diff --git a/smt/pkg/smt/witness.go b/smt/pkg/smt/witness.go index ce64d08107e..5fc7d64e336 100644 --- a/smt/pkg/smt/witness.go +++ b/smt/pkg/smt/witness.go @@ -2,12 +2,16 @@ package smt import ( "context" + "fmt" + "math/big" libcommon "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon/smt/pkg/utils" "github.com/ledgerwatch/erigon/turbo/trie" + "github.com/status-im/keycard-go/hexutils" ) +// BuildWitness creates a witness from the SMT func BuildWitness(s *SMT, rd trie.RetainDecider, ctx context.Context) (*trie.Witness, error) { operands := make([]trie.WitnessOperator, 0) @@ -33,7 +37,7 @@ func BuildWitness(s *SMT, rd trie.RetainDecider, ctx context.Context) (*trie.Wit This algorithm adds a little bit more nodes to the witness but it ensures that all requiring nodes are included. */ - retain := true + var retain bool prefixLen := len(prefix) if prefixLen > 0 { @@ -112,3 +116,166 @@ func BuildWitness(s *SMT, rd trie.RetainDecider, ctx context.Context) (*trie.Wit return trie.NewWitness(operands), err } + +// BuildSMTfromWitness builds SMT from witness +func BuildSMTfromWitness(w *trie.Witness) (*SMT, error) { + // using memdb + s := NewSMT(nil, false) + + balanceMap := make(map[string]*big.Int) + nonceMap := make(map[string]*big.Int) + contractMap := make(map[string]string) + storageMap := make(map[string]map[string]string) + + path := make([]int, 0) + + firstNode := true + NodeChildCountMap := make(map[string]uint32) + NodesBranchValueMap := make(map[string]uint32) + + type nodeHash struct { + path []int + hash libcommon.Hash + } + + nodeHashes := make([]nodeHash, 0) + + for i, operator := range w.Operators { + switch op := operator.(type) { + case *trie.OperatorSMTLeafValue: + valScaler := big.NewInt(0).SetBytes(op.Value) + addr := libcommon.BytesToAddress(op.Address) + + switch op.NodeType { + case utils.KEY_BALANCE: + balanceMap[addr.String()] = valScaler + + case utils.KEY_NONCE: + nonceMap[addr.String()] = valScaler + + case utils.SC_STORAGE: + if _, ok := storageMap[addr.String()]; !ok { + storageMap[addr.String()] = make(map[string]string) + } + + stKey := hexutils.BytesToHex(op.StorageKey) + if len(stKey) > 0 { + stKey = fmt.Sprintf("0x%s", stKey) + } + + storageMap[addr.String()][stKey] = valScaler.String() + } + + path = path[:len(path)-1] + NodeChildCountMap[intArrayToString(path)] += 1 + + for len(path) != 0 && NodeChildCountMap[intArrayToString(path)] == NodesBranchValueMap[intArrayToString(path)] { + path = path[:len(path)-1] + } + if NodeChildCountMap[intArrayToString(path)] < NodesBranchValueMap[intArrayToString(path)] { + path = append(path, 1) + } + + case *trie.OperatorCode: + addr := libcommon.BytesToAddress(w.Operators[i+1].(*trie.OperatorSMTLeafValue).Address) + + code := hexutils.BytesToHex(op.Code) + if len(code) > 0 { + if err := s.Db.AddCode(hexutils.HexToBytes(code)); err != nil { + return nil, err + } + code = fmt.Sprintf("0x%s", code) + } + + contractMap[addr.String()] = code + + case *trie.OperatorBranch: + if firstNode { + firstNode = false + } else { + NodeChildCountMap[intArrayToString(path[:len(path)-1])] += 1 + } + + switch op.Mask { + case 1: + NodesBranchValueMap[intArrayToString(path)] = 1 + path = append(path, 0) + case 2: + NodesBranchValueMap[intArrayToString(path)] = 1 + path = append(path, 1) + case 3: + NodesBranchValueMap[intArrayToString(path)] = 2 + path = append(path, 0) + } + + case *trie.OperatorHash: + pathCopy := make([]int, len(path)) + copy(pathCopy, path) + nodeHashes = append(nodeHashes, nodeHash{path: pathCopy, hash: op.Hash}) + + path = path[:len(path)-1] + NodeChildCountMap[intArrayToString(path)] += 1 + + for len(path) != 0 && NodeChildCountMap[intArrayToString(path)] == NodesBranchValueMap[intArrayToString(path)] { + path = path[:len(path)-1] + } + if NodeChildCountMap[intArrayToString(path)] < NodesBranchValueMap[intArrayToString(path)] { + path = append(path, 1) + } + + default: + // Unsupported operator type + return nil, fmt.Errorf("unsupported operator type: %T", op) + } + } + + for _, nodeHash := range nodeHashes { + _, err := s.InsertHashNode(nodeHash.path, nodeHash.hash.Big()) + if err != nil { + return nil, err + } + + _, err = s.Db.GetLastRoot() + if err != nil { + return nil, err + } + } + + for addr, balance := range balanceMap { + _, err := s.SetAccountBalance(addr, balance) + if err != nil { + return nil, err + } + } + + for addr, nonce := range nonceMap { + _, err := s.SetAccountNonce(addr, nonce) + if err != nil { + return nil, err + } + } + + for addr, code := range contractMap { + err := s.SetContractBytecode(addr, code) + if err != nil { + return nil, err + } + } + + for addr, storage := range storageMap { + _, err := s.SetContractStorage(addr, storage, nil) + if err != nil { + fmt.Println("error : unable to set contract storage", err) + } + } + + return s, nil +} + +func intArrayToString(a []int) string { + s := "" + for _, v := range a { + s += fmt.Sprintf("%d", v) + } + return s +} diff --git a/smt/pkg/smt/witness_test.go b/smt/pkg/smt/witness_test.go index ca0dcdf2f91..b055f375e78 100644 --- a/smt/pkg/smt/witness_test.go +++ b/smt/pkg/smt/witness_test.go @@ -16,9 +16,12 @@ import ( "github.com/ledgerwatch/erigon/smt/pkg/smt" "github.com/ledgerwatch/erigon/smt/pkg/utils" "github.com/ledgerwatch/erigon/turbo/trie" + "github.com/stretchr/testify/require" ) func prepareSMT(t *testing.T) (*smt.SMT, *trie.RetainList) { + t.Helper() + contract := libcommon.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624") balance := uint256.NewInt(1000000000) sKey := libcommon.HexToHash("0x5") @@ -43,46 +46,46 @@ func prepareSMT(t *testing.T) (*smt.SMT, *trie.RetainList) { intraBlockState.AddBalance(contract, balance) intraBlockState.SetState(contract, &sKey, *sVal) - if err := intraBlockState.FinalizeTx(&chain.Rules{}, tds.TrieStateWriter()); err != nil { - t.Errorf("error finalising 1st tx: %v", err) - } - if err := intraBlockState.CommitBlock(&chain.Rules{}, w); err != nil { - t.Errorf("error committing block: %v", err) - } + err := intraBlockState.FinalizeTx(&chain.Rules{}, tds.TrieStateWriter()) + require.NoError(t, err, "error finalising 1st tx") - rl, err := tds.ResolveSMTRetainList() + err = intraBlockState.CommitBlock(&chain.Rules{}, w) + require.NoError(t, err, "error committing block") - if err != nil { - t.Errorf("error resolving state trie: %v", err) - } + rl, err := tds.ResolveSMTRetainList() + require.NoError(t, err, "error resolving state trie") memdb := db.NewMemDb() smtTrie := smt.NewSMT(memdb, false) - smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(1).ToBig()) - smtTrie.SetContractBytecode(contract.String(), hex.EncodeToString(code)) - err = memdb.AddCode(code) + _, err = smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(1).ToBig()) + require.NoError(t, err) - if err != nil { - t.Errorf("error adding code to memdb: %v", err) - } + err = smtTrie.SetContractBytecode(contract.String(), hex.EncodeToString(code)) + require.NoError(t, err) + + err = memdb.AddCode(code) + require.NoError(t, err, "error adding code to memdb") storage := make(map[string]string, 0) for i := 0; i < 100; i++ { - k := libcommon.HexToHash(fmt.Sprintf("0x%d", i)) - storage[k.String()] = k.String() + k := libcommon.HexToHash(fmt.Sprintf("0x%d", i)).String() + storage[k] = k } storage[sKey.String()] = sVal.String() - smtTrie.SetContractStorage(contract.String(), storage, nil) + _, err = smtTrie.SetContractStorage(contract.String(), storage, nil) + require.NoError(t, err) return smtTrie, rl } func findNode(t *testing.T, w *trie.Witness, addr libcommon.Address, storageKey libcommon.Hash, nodeType int) []byte { + t.Helper() + for _, operator := range w.Operators { switch op := operator.(type) { case *trie.OperatorSMTLeafValue: @@ -109,23 +112,19 @@ func TestSMTWitnessRetainList(t *testing.T) { sVal := uint256.NewInt(0xdeadbeef) witness, err := smt.BuildWitness(smtTrie, rl, context.Background()) - - if err != nil { - t.Errorf("error building witness: %v", err) - } + require.NoError(t, err, "error building witness") foundCode := findNode(t, witness, contract, libcommon.Hash{}, utils.SC_CODE) foundBalance := findNode(t, witness, contract, libcommon.Hash{}, utils.KEY_BALANCE) foundNonce := findNode(t, witness, contract, libcommon.Hash{}, utils.KEY_NONCE) foundStorage := findNode(t, witness, contract, sKey, utils.SC_STORAGE) - if foundCode == nil || foundBalance == nil || foundNonce == nil || foundStorage == nil { - t.Errorf("witness does not contain all expected operators") - } + require.NotNil(t, foundCode) + require.NotNil(t, foundBalance) + require.NotNil(t, foundNonce) + require.NotNil(t, foundStorage) - if !bytes.Equal(foundStorage, sVal.Bytes()) { - t.Errorf("witness contains unexpected storage value") - } + require.Equal(t, foundStorage, sVal.Bytes(), "witness contains unexpected storage value") } func TestSMTWitnessRetainListEmptyVal(t *testing.T) { @@ -136,25 +135,106 @@ func TestSMTWitnessRetainListEmptyVal(t *testing.T) { sKey := libcommon.HexToHash("0x5") // Set nonce to 0 - smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(0).ToBig()) + _, err := smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(0).ToBig()) + require.NoError(t, err) witness, err := smt.BuildWitness(smtTrie, rl, context.Background()) - - if err != nil { - t.Errorf("error building witness: %v", err) - } + require.NoError(t, err, "error building witness") foundCode := findNode(t, witness, contract, libcommon.Hash{}, utils.SC_CODE) foundBalance := findNode(t, witness, contract, libcommon.Hash{}, utils.KEY_BALANCE) foundNonce := findNode(t, witness, contract, libcommon.Hash{}, utils.KEY_NONCE) foundStorage := findNode(t, witness, contract, sKey, utils.SC_STORAGE) - if foundCode == nil || foundBalance == nil || foundStorage == nil { - t.Errorf("witness does not contain all expected operators") - } + // Code, balance and storage should be present in the witness + require.NotNil(t, foundCode) + require.NotNil(t, foundBalance) + require.NotNil(t, foundStorage) // Nonce should not be in witness - if foundNonce != nil { - t.Errorf("witness contains unexpected operator") - } + require.Nil(t, foundNonce, "witness contains unexpected operator") +} + +// TestWitnessToSMT tests that the SMT built from a witness matches the original SMT +func TestWitnessToSMT(t *testing.T) { + smtTrie, rl := prepareSMT(t) + + witness, err := smt.BuildWitness(smtTrie, rl, context.Background()) + require.NoError(t, err, "error building witness") + + newSMT, err := smt.BuildSMTfromWitness(witness) + require.NoError(t, err, "error building SMT from witness") + + root, err := newSMT.Db.GetLastRoot() + require.NoError(t, err, "error getting last root from db") + + // newSMT.Traverse(context.Background(), root, func(prefix []byte, k utils.NodeKey, v utils.NodeValue12) (bool, error) { + // fmt.Printf("[After] path: %v, hash: %x\n", prefix, libcommon.BigToHash(k.ToBigInt())) + // return true, nil + // }) + + expectedRoot, err := smtTrie.Db.GetLastRoot() + require.NoError(t, err, "error getting last root") + + // assert that the roots are the same + require.Equal(t, expectedRoot, root, "SMT root mismatch") +} + +// TestWitnessToSMTStateReader tests that the SMT built from a witness matches the state +func TestWitnessToSMTStateReader(t *testing.T) { + smtTrie, rl := prepareSMT(t) + + sKey := libcommon.HexToHash("0x5") + + expectedRoot, err := smtTrie.Db.GetLastRoot() + require.NoError(t, err, "error getting last root") + + witness, err := smt.BuildWitness(smtTrie, rl, context.Background()) + require.NoError(t, err, "error building witness") + + newSMT, err := smt.BuildSMTfromWitness(witness) + require.NoError(t, err, "error building SMT from witness") + + root, err := newSMT.Db.GetLastRoot() + require.NoError(t, err, "error getting the last root from db") + + require.Equal(t, expectedRoot, root, "SMT root mismatch") + + contract := libcommon.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624") + + expectedAcc, err := smtTrie.ReadAccountData(contract) + require.NoError(t, err) + + newAcc, err := newSMT.ReadAccountData(contract) + require.NoError(t, err) + + expectedAccCode, err := smtTrie.ReadAccountCode(contract, 0, expectedAcc.CodeHash) + require.NoError(t, err) + + newAccCode, err := newSMT.ReadAccountCode(contract, 0, newAcc.CodeHash) + require.NoError(t, err) + + expectedAccCodeSize, err := smtTrie.ReadAccountCodeSize(contract, 0, expectedAcc.CodeHash) + require.NoError(t, err) + + newAccCodeSize, err := newSMT.ReadAccountCodeSize(contract, 0, newAcc.CodeHash) + require.NoError(t, err) + + expectedStorageValue, err := smtTrie.ReadAccountStorage(contract, 0, &sKey) + require.NoError(t, err) + + newStorageValue, err := newSMT.ReadAccountStorage(contract, 0, &sKey) + require.NoError(t, err) + + // assert that the account data is the same + require.Equal(t, expectedAcc, newAcc) + + // assert that account code is the same + require.Equal(t, expectedAccCode, newAccCode) + + // assert that the account code size is the same + require.Equal(t, expectedAccCodeSize, newAccCodeSize) + + // assert that the storage value is the same + require.Equal(t, expectedStorageValue, newStorageValue) }