Skip to content
This repository has been archived by the owner on Jul 5, 2024. It is now read-only.

Commit

Permalink
Feat/#1752 fix witness generation - adding testing (#1784)
Browse files Browse the repository at this point in the history
### Description

closed #1752 

- Adding different numbers of txs to test different trie status (e.g.
one leaf only, one ext. , one branch and one leaf ...etc)
- Fixing `GetProof` (if input is an ext node, it always assign first
child of the root ext. node, `st.children[0]`)
- Fixing `getNodeFromBranchRLP` (if input is an leaf, it could throw an
exception)
- Refactoring `getNodeFromBranchRLP`

### Issue Link

#1752 

### Type of change

- [x] Bug fix (non-breaking change which fixes an issue)
  • Loading branch information
KimiWu123 authored Mar 7, 2024
1 parent 41e3408 commit b82ea41
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 99 deletions.
143 changes: 91 additions & 52 deletions geth-utils/gethutil/mpt/trie/stacktrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,17 @@ func (st *StackTrie) insert(key, value []byte) {
break
}
}

// Add new child
if st.children[idx] == nil {
st.children[idx] = stackTrieFromPool(st.db)
st.children[idx].keyOffset = st.keyOffset + 1
st.children[idx] = newLeaf(st.keyOffset+1, key, value, st.db)
} else {
st.children[idx].insert(key, value)
}
st.children[idx].insert(key, value)

case extNode: /* Ext */
// Compare both key chunks and see where they differ
diffidx := st.getDiffIndex(key)

// Check if chunks are identical. If so, recurse into
// the child node. Otherwise, the key has to be split
// into 1) an optional common prefix, 2) the fullnode
Expand Down Expand Up @@ -551,57 +552,85 @@ func (st *StackTrie) Commit() (common.Hash, error) {
return common.BytesToHash(st.val), nil
}

func (st *StackTrie) getNodeFromBranchRLP(branch []byte, ind byte) []byte {
start := 2 // when branch[0] == 248
if branch[0] == 249 {
start = 3
}

i := 0
insideInd := -1
cInd := byte(0)
for {
if start+i == len(branch)-1 { // -1 because of the last 128 (branch value)
return []byte{0}
}
b := branch[start+i]
if insideInd == -1 && b == 128 {
if cInd == ind {
const RLP_SHORT_STR_FLAG = 128
const RLP_SHORT_LIST_FLAG = 192
const RLP_LONG_LIST_FLAG = 248
const LEN_OF_HASH = 32

// Note:
// In RLP encoding, if the value is between [0x80, 0xb7] ([128, 183]),
// it means following data is a short string (0 - 55bytes).
// Which implies if the value is 128, it's an empty string.
func (st *StackTrie) getNodeFromBranchRLP(branch []byte, idx int) []byte {

start := int(branch[0])
start_idx := 0
if start >= RLP_SHORT_LIST_FLAG && start < RLP_LONG_LIST_FLAG {
// In RLP encoding, length in the range of [192 248] is a short list.
// In stack trie, it usually means an extension node and the first byte is nibble
// and that's why we start from 2
start_idx = 2
} else if start >= RLP_LONG_LIST_FLAG {
// In RLP encoding, length in the range of [248 ~ ] is a long list.
// The RLP byte minus 248 (branch[0] - 248) is the length in bytes of the length of the payload
// and the payload is right after the length.
// That's why we add 2 here
// e.g. [248 81 128 160 ...]
// `81` is the length of the payload and payload starts from `128`
start_idx = start - RLP_LONG_LIST_FLAG + 2
}

// If 1st node is neither 128(empty node) nor 160, it should be a leaf
b := int(branch[start_idx])
if b != RLP_SHORT_STR_FLAG && b != (RLP_SHORT_STR_FLAG+LEN_OF_HASH) {
return []byte{0}
}

current_idx := 0
for i := start_idx; i < len(branch); i++ {
b = int(branch[i])
switch b {
case RLP_SHORT_STR_FLAG: // 128
// if the current index is we're looking for, return an empty node directly
if current_idx == idx {
return []byte{128}
} else {
cInd += 1
}
} else if insideInd == -1 && b != 128 {
if b == 160 {
if cInd == ind {
return branch[start+i+1 : start+i+1+32]
}
insideInd = 32
} else {
// non-hashed node
if cInd == ind {
return branch[start+i+1 : start+i+1+int(b)-192]
}
insideInd = int(b) - 192
current_idx++
case RLP_SHORT_STR_FLAG + LEN_OF_HASH: // 160
if current_idx == idx {
return branch[i+1 : i+1+LEN_OF_HASH]
}
cInd += 1
} else {
if insideInd == 1 {
insideInd = -1
} else {
insideInd--
// jump to next encoded element
i += LEN_OF_HASH
current_idx++
default:
if b >= 192 && b < 248 {
length := b - 192
if current_idx == idx {
return branch[i+1 : i+1+length]
}
i += length
current_idx++
}
}

i++
}

return []byte{0}
}

type StackProof struct {
proofS [][]byte
proofC [][]byte
}

func (sp *StackProof) GetProofS() [][]byte {
return sp.proofS
}

func (sp *StackProof) GetProofC() [][]byte {
return sp.proofC
}

func (st *StackTrie) UpdateAndGetProof(db ethdb.KeyValueReader, indexBuf, value []byte) (StackProof, error) {
proofS, err := st.GetProof(db, indexBuf)
if err != nil {
Expand All @@ -618,6 +647,8 @@ func (st *StackTrie) UpdateAndGetProof(db ethdb.KeyValueReader, indexBuf, value
return StackProof{proofS, proofC}, nil
}

// We refer to the link below for this function.
// https://github.com/ethereum/go-ethereum/blob/00905f7dc406cfb67f64cd74113777044fb886d8/core/types/hashing.go#L105-L134
func (st *StackTrie) UpdateAndGetProofs(db ethdb.KeyValueReader, list types.DerivableList) ([]StackProof, error) {
valueBuf := types.EncodeBufferPool.Get().(*bytes.Buffer)
defer types.EncodeBufferPool.Put(valueBuf)
Expand All @@ -631,33 +662,40 @@ func (st *StackTrie) UpdateAndGetProofs(db ethdb.KeyValueReader, list types.Deri
for i := 1; i < list.Len() && i <= 0x7f; i++ {
indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
value := types.EncodeForDerive(list, i, valueBuf)

proof, err := st.UpdateAndGetProof(db, indexBuf, value)
if err != nil {
return nil, err
}

proofs = append(proofs, proof)
}

// special case when index is 0
// rlp.AppendUint64() encodes index 0 to [128]
if list.Len() > 0 {
indexBuf = rlp.AppendUint64(indexBuf[:0], 0)
value := types.EncodeForDerive(list, 0, valueBuf)
// TODO: get proof
st.Update(indexBuf, value)
proof, err := st.UpdateAndGetProof(db, indexBuf, value)
if err != nil {
return nil, err
}
proofs = append(proofs, proof)
}

for i := 0x80; i < list.Len(); i++ {
indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
value := types.EncodeForDerive(list, i, valueBuf)
// TODO: get proof
st.Update(indexBuf, value)
proof, err := st.UpdateAndGetProof(db, indexBuf, value)
if err != nil {
return nil, err
}
proofs = append(proofs, proof)
}

return proofs, nil
}

func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, error) {
k := KeybytesToHex(key)

if st.nodeType == emptyNode {
return [][]byte{}, nil
}
Expand All @@ -682,7 +720,8 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, er
for i := 0; i < len(k); i++ {
if c.nodeType == extNode {
nodes = append(nodes, c)
c = st.children[0]
c = c.children[0]

} else if c.nodeType == branchNode {
nodes = append(nodes, c)
c = c.children[k[i]]
Expand All @@ -700,11 +739,11 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, er
}

proof = append(proof, c_rlp)
branchChild := st.getNodeFromBranchRLP(c_rlp, k[i])
branchChild := st.getNodeFromBranchRLP(c_rlp, int(k[i]))

// branchChild is of length 1 when there is no child at this position in the branch
// (`branchChild = [128]` in this case), but it is also of length 1 when `c_rlp` is a leaf.
if len(branchChild) == 1 {
if len(branchChild) == 1 && (branchChild[0] == 128 || branchChild[0] == 0) {
// no child at this position - 128 is RLP encoding for nil object
break
}
Expand Down
Loading

0 comments on commit b82ea41

Please sign in to comment.