Skip to content

Commit

Permalink
Trie - async walk generator (#2904)
Browse files Browse the repository at this point in the history
* Trie: async iterator _walkTrie function

* write test/demo script

* Trie: internalize walkTrieIterable into Trie class

* Trie: include helper methods for all nodes / value nodes

* remove null conditional

* update test with sparse trie example

---------

Co-authored-by: acolytec3 <17355484+acolytec3@users.noreply.github.com>
Co-authored-by: Holger Drewes <Holger.Drewes@gmail.com>
  • Loading branch information
3 people authored Jul 26, 2023
1 parent 191faf5 commit e984704
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 0 deletions.
33 changes: 33 additions & 0 deletions packages/trie/src/trie.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
} from './node/index.js'
import { verifyRangeProof } from './proof/range.js'
import { ROOT_DB_KEY } from './types.js'
import { _walkTrie } from './util/asyncWalk.js'
import { Lock } from './util/lock.js'
import { bytesToNibbles, doKeysMatch, matchingNibbleLength } from './util/nibbles.js'
import { TrieReadStream as ReadStream } from './util/readStream.js'
Expand All @@ -35,6 +36,7 @@ import type {
TrieOpts,
TrieOptsWithDefaults,
} from './types.js'
import type { OnFound } from './util/asyncWalk.js'
import type { BatchDBOp, DB, PutBatch } from '@ethereumjs/util'

interface Path {
Expand Down Expand Up @@ -350,6 +352,37 @@ export class Trie {
await WalkController.newWalk(onFound, this, root)
}

/**
 * Async-generator walk over the trie: `_walkTrie` (from ./util/asyncWalk.js) bound to
 * this instance. Consume with
 * `for await (const { node, currentKey } of trie.walkTrieIterable(trie.root())) { ... }`
 */
walkTrieIterable = _walkTrie.bind(this)

/**
 * Walks the trie from the current root and hands every node the walk yields to a callback.
 * @param onFound - awaited once per yielded node, with the node and the key leading to it.
 * @returns Resolves once the walk is exhausted.
 */
async walkAllNodes(onFound: OnFound): Promise<void> {
  const allNodes = this.walkTrieIterable(this.root())
  for await (const found of allNodes) {
    await onFound(found.node, found.currentKey)
  }
}

/**
 * Walks the trie from the current root and hands every value-carrying node to a callback.
 * A node counts as value-carrying when it is a leaf node, or a branch node whose
 * `value()` is not `null`.
 * @param onFound - awaited once per matching node, with the node and the key leading to it.
 * @returns Resolves once the walk is exhausted.
 */
async walkAllValueNodes(onFound: OnFound): Promise<void> {
  const valueNodes = this.walkTrieIterable(this.root(), [], undefined, async (candidate) => {
    if (candidate instanceof LeafNode) {
      return true
    }
    return candidate instanceof BranchNode && candidate.value() !== null
  })
  for await (const found of valueNodes) {
    await onFound(found.node, found.currentKey)
  }
}

/**
* Creates the initial node from an empty tree.
* @private
Expand Down
60 changes: 60 additions & 0 deletions packages/trie/src/util/asyncWalk.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import { RLP } from '@ethereumjs/rlp'
import { equalsBytes, toHex } from 'ethereum-cryptography/utils'

import { BranchNode } from '../node/branch.js'
import { ExtensionNode } from '../node/extension.js'

import type { Trie } from '../trie'
import type { TrieNode } from '../types'

/**
 * Async predicate applied to each visited node: receives the node and the (partial) key
 * leading to it; the walk yields the node only when this resolves to `true`.
 */
export type NodeFilter = (node: TrieNode, key: number[]) => Promise<boolean>
/**
 * Async callback invoked for a visited node with the node and the (partial) key leading
 * to it. The walk awaits it but does not use the resolved value.
 */
export type OnFound = (node: TrieNode, key: number[]) => Promise<any>

/**
 * Walk Trie via async generator (depth-first, parent before children).
 *
 * Starting at `nodeHash`, each node is looked up, reported to `onFound`, and yielded
 * when `filter` resolves to `true`. Branch children are visited in nibble order; the
 * nibbles of an extension node are appended to the key before descending.
 *
 * A node that cannot be looked up (for example in a sparse trie built from a proof)
 * ends that part of the walk quietly instead of throwing. Only the lookup is guarded:
 * errors thrown by `onFound` or `filter` propagate to the consumer of the generator.
 *
 * @param nodeHash - The root key to walk on.
 * @param currentKey - The current (partial) key, as nibbles.
 * @param onFound - Called on every node found (before filter).
 * @param filter - Filter nodes yielded by the generator.
 * @param visited - Set of hex-encoded hashes of nodes already visited; a node whose
 *   hash is already present is neither reported nor descended into again.
 * @returns AsyncIterable<{ node: TrieNode; currentKey: number[] }>
 * Iterate through nodes with
 * `for await (const { node, currentKey } of trie._walkTrie(root)) { ... }`
 */
export async function* _walkTrie(
  this: Trie,
  nodeHash: Uint8Array,
  currentKey: number[] = [],
  onFound: OnFound = async (_trieNode: TrieNode, _key: number[]) => {},
  filter: NodeFilter = async (_trieNode: TrieNode, _key: number[]) => true,
  visited: Set<string> = new Set<string>()
): AsyncIterable<{ node: TrieNode; currentKey: number[] }> {
  if (equalsBytes(nodeHash, this.EMPTY_TRIE_ROOT)) {
    return
  }

  // Guard the lookup only: a missing node is expected in sparse tries, whereas a
  // failure inside a caller-supplied callback must not be hidden.
  let node: TrieNode | null | undefined
  try {
    node = await this.lookupNode(nodeHash)
  } catch {
    return
  }
  if (node === undefined || node === null) {
    return
  }

  // Hash the node once and reuse it for both the membership test and the insert.
  const visitedKey = toHex(this.hash(node.serialize()))
  if (visited.has(visitedKey)) {
    return
  }
  visited.add(visitedKey)

  await onFound(node, currentKey)
  if (await filter(node, currentKey)) {
    yield { node, currentKey }
  }

  if (node instanceof BranchNode) {
    for (const [nibble, childNode] of node._branches.entries()) {
      // An empty slot has no child to walk; skip it instead of hashing and recursing.
      if (childNode === null || childNode === undefined || childNode.length === 0) {
        continue
      }
      // NOTE(review): a child that is not a Uint8Array is addressed by the hash of its
      // RLP encoding; confirm such nodes are resolvable through lookupNode, otherwise
      // they are skipped by the guarded lookup above.
      const childHash: Uint8Array =
        childNode instanceof Uint8Array ? childNode : this.hash(RLP.encode(childNode))
      yield* _walkTrie.call(this, childHash, [...currentKey, nibble], onFound, filter, visited)
    }
  } else if (node instanceof ExtensionNode) {
    const extendedKey = [...currentKey, ...node._nibbles]
    yield* _walkTrie.call(this, node.value(), extendedKey, onFound, filter, visited)
  }
}
183 changes: 183 additions & 0 deletions packages/trie/test/util/asyncWalk.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/* eslint-disable no-console */
import { bytesToHex, equalsBytes, hexToBytes, utf8ToBytes } from '@ethereumjs/util'
import { assert, describe, it } from 'vitest'

import { BranchNode, ExtensionNode, LeafNode, Trie } from '../../src/index.js'
import { _walkTrie } from '../../src/util/asyncWalk.js'
import { bytesToNibbles } from '../../src/util/nibbles.js'
import trieTests from '../fixtures/trietest.json'

import type { TrieNode } from '../../src/index.js'

/**
 * Maps a trie node to the short label used in the walk log output.
 * @param node - node to classify
 * @returns 'BranchNode', 'Ext_Node' or 'LeafNode'
 * @throws Error when the node is an instance of none of the three node classes
 */
function getNodeType(node: TrieNode): string {
  const label =
    node instanceof BranchNode
      ? 'BranchNode'
      : node instanceof ExtensionNode
      ? 'Ext_Node'
      : node instanceof LeafNode
      ? 'LeafNode'
      : undefined
  if (label === undefined) {
    throw new Error(`Unknown node type: ${node}`)
  }
  return label
}

/**
 * Prints a human-readable description of one visited node to the console: its type,
 * the first characters of its hash (with a banner when the hash equals the trie root),
 * the key walked so far, and the key(s) the walk continues to from this node.
 * @param trie - trie the node belongs to (used for hashing and root comparison)
 * @param node - node being described
 * @param currentKey - nibble key leading to the node
 */
function logNode(trie: Trie, node: TrieNode, currentKey: number[]): void {
  const divider = '--------------------------'
  console.log(divider)
  console.log(`------- \u2705 { ${getNodeType(node)} } \u2705 `)

  // Hash once; it drives both the root comparison and the printed prefix.
  const nodeHash = (trie as any).hash(node.serialize())
  const isRoot = equalsBytes(nodeHash, trie.root())
  const hashPrefix = bytesToHex(nodeHash).slice(0, 12)
  const rootBanner = isRoot
    ? ' \uD83D\uDCA5 \u211B \u2134 \u2134 \u0164 \u0147 \u2134 \u0221 \u2211 \u2737'
    : ''
  console.log(`{ 0x${hashPrefix}... } ----${rootBanner}`)

  const extensionHint = node instanceof ExtensionNode ? `((${node._nibbles}))` : ''
  console.log('walk from', `[${currentKey}]`, extensionHint)

  if ('_nibbles' in node) {
    console.log(` -- to =>`, `[${node._nibbles}]`)
    console.log(` -- next key: [${[...currentKey, node._nibbles]}]`)
  } else if ('_branches' in node) {
    // Print one entry per occupied branch slot, with a separator between entries.
    let needsSeparator = false
    for (const [nibble, child] of node._branches.entries()) {
      if (child === null || child.length === 0) {
        continue
      }
      if (needsSeparator) {
        console.log('\uD83D\uDDD8 \u0026')
      }
      needsSeparator = true
      console.log(` -- to =>`, `[${nibble}]`)
      console.log(` -- next key: [${[...currentKey, [nibble]]}]`)
    }
  }
  console.log(divider)
}

// For every fixture in trietest.json: build a trie from the fixture inputs (checking
// reads after each put), check the resulting root, then walk the trie with _walkTrie
// and log each node that is yielded.
describe('walk the tries from official tests', async () => {
const testNames = Object.keys(trieTests.tests)

for await (const testName of testNames) {
// One trie per fixture; it is filled while the nested suites below are being set up.
const trie = new Trie()
describe(testName, async () => {
const inputs = (trieTests as any).tests[testName].in
const expect = (trieTests as any).tests[testName].root
// hex-encoded key -> value (as converted below) for every input processed so far
const testKeys: Map<string, Uint8Array | null> = new Map()
// hex-encoded key -> original [key, value] strings from the fixture, used in test titles
const testStrings: Map<string, [string, string | null]> = new Map()
for await (const [idx, input] of inputs.entries()) {
// Captured before conversion so the titles show the fixture's own strings.
const stringPair: [string, string] = [inputs[idx][0], inputs[idx][1] ?? 'null']
describe(`put: ${stringPair}`, async () => {
// Convert key and value in place: '0x'-prefixed strings as hex, other strings as UTF-8.
for (let i = 0; i < 2; i++) {
if (typeof input[i] === 'string' && input[i].slice(0, 2) === '0x') {
input[i] = hexToBytes(input[i])
} else if (typeof input[i] === 'string') {
input[i] = utf8ToBytes(input[i])
}
}
try {
await trie.put(input[0], input[1])
assert(true)
} catch (e) {
assert(false, (e as any).message)
}
trie.checkpoint()
await trie.commit()
trie.flushCheckpoints()
testKeys.set(bytesToHex(input[0]), input[1])
testStrings.set(bytesToHex(input[0]), stringPair)
// After each put, read back every key recorded so far and compare with its recorded value.
describe(`should get all keys`, async () => {
for await (const [key, val] of testKeys.entries()) {
const retrieved = await trie.get(hexToBytes(key))
it(`should get ${testStrings.get(key)}`, async () => {
assert.deepEqual(retrieved, val)
})
}
})
})
}
it(`should have root ${expect}`, async () => {
assert.equal(bytesToHex(trie.root()), expect)
})
describe('walkTrie', async () => {
// Walk from the current root with the default onFound/filter arguments.
const walker = _walkTrie.bind(trie)(trie.root(), [])
console.log(`----------- { { { test: ${testName} } } } ---------`)
// For the 'branchingTests' fixture the log announces that the trie should be empty.
testName === 'branchingTests' &&
console.log(
` \uD83C\uDF10 \u267B this trie should be empty \u267B \uD83C\uDF10 `
)
testName === 'branchingTests' && console.log('--------------------------')

// The walk output is only logged, not asserted on.
for await (const { currentKey, node } of walker) {
logNode(trie, node, currentKey)
}

// Placeholder test that always passes; it is registered after the walk above has completed.
it('should be done', async () => {
assert.equal(true, true)
})
})
})
}
})

// Builds a full trie from the 'jeff' fixture, creates a proof for its first key, loads
// that proof into a fresh trie, and checks that the async walk visits exactly the proof
// nodes without throwing, whereas the WalkController-based walkTrie throws on the
// missing nodes.
describe('walk a sparse trie', async () => {
const trie = new Trie()
const inputs = (trieTests as any).tests.jeff.in
const expect = (trieTests as any).tests.jeff.root

// Build a Trie
for await (const input of inputs) {
// Convert key and value in place: '0x'-prefixed strings as hex, other strings as UTF-8.
for (let i = 0; i < 2; i++) {
if (typeof input[i] === 'string' && input[i].slice(0, 2) === '0x') {
input[i] = hexToBytes(input[i])
} else if (typeof input[i] === 'string') {
input[i] = utf8ToBytes(input[i])
}
}
await trie.put(input[0], input[1])
}
// Check the root
it(`should have root ${expect}`, async () => {
assert.equal(bytesToHex(trie.root()), expect)
})
// Generate a proof for inputs[0]
// (inputs[0][0] has already been converted to bytes by the loop above)
const proofKey = inputs[0][0]
const proof = await trie.createProof(proofKey)
assert.ok(await trie.verifyProof(trie.root(), proofKey, proof))

// Build a sparse trie from the proof
const fromProof = new Trie()
await fromProof.fromProof(proof)

// Walk the sparse trie
const walker = fromProof.walkTrieIterable(fromProof.root())
let found = 0
for await (const { currentKey, node } of walker) {
if (equalsBytes((fromProof as any).hash(node.serialize()), fromProof.root())) {
// The root of proof trie should be same as original
assert.deepEqual(fromProof.root(), trie.root())
}
if (node instanceof LeafNode) {
// The only leaf node should be leaf from the proof
// Key walked so far plus the leaf's own nibbles must equal the nibbles of the proof key.
const fullKeyNibbles = [...currentKey, ...node._nibbles]
assert.deepEqual(fullKeyNibbles, bytesToNibbles(proofKey))
assert.deepEqual(node.value(), inputs[0][1])
}
// Count the nodes...nodes from the proof should be only nodes in the trie
found++
}
assert.equal(found, proof.length)
assert.ok(true, 'Walking sparse trie should not throw error')

// Walk the same sparse trie with WalkController
try {
await fromProof.walkTrie(fromProof.root(), async (noderef, node, key, wc) => {
wc.allChildren(node!)
})
// Reaching this line means walkTrie did not throw, which is the failure case here.
assert.fail('Will throw when it meets a missing node in a sparse trie')
} catch (err) {
assert.equal((err as any).message, 'Missing node in DB')
}
})

0 comments on commit e984704

Please sign in to comment.