Centralize trie-traversal for proof generation (#380)
richardpringle authored Dec 1, 2023
1 parent fedc673 commit 11285ca
Showing 5 changed files with 314 additions and 138 deletions.
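The heart of the change: the bespoke nibble-walking loop inside the proof-generation routine is deleted and replaced with a call to get_node_by_key_with_callbacks, which now drives every key lookup. The traversal takes two closures, one invoked before each step with the current node's DiskAddress and nibble, and one invoked after the child pointer is resolved with the ObjRef and nibble, so plain lookup, parent collection, and proof-path collection all share a single loop. The sketch below is a minimal, self-contained illustration of that pattern with toy types (Addr, Node, and Trie are hypothetical stand-ins, not firewood's ObjRef/ShaleStore machinery), arranged in the same shape as the helpers in the diff.

use std::collections::HashMap;

// Hypothetical stand-in for a node's on-disk address.
type Addr = usize;

struct Node {
    children: HashMap<u8, Addr>, // child address per nibble
    value: Option<Vec<u8>>,
}

struct Trie {
    nodes: Vec<Node>, // index 0 is the root
}

impl Trie {
    // The single traversal loop: walk one nibble at a time, calling
    // `on_enter` with the current address before resolving the child and
    // `on_descend` with (address, nibble) once the child pointer is known.
    fn get_with_callbacks(
        &self,
        key: &[u8],
        mut on_enter: impl FnMut(Addr, u8),
        mut on_descend: impl FnMut(Addr, u8),
    ) -> Option<Addr> {
        let mut cur: Addr = 0;

        for &nib in key {
            on_enter(cur, nib);
            let next = *self.nodes[cur].children.get(&nib)?;
            on_descend(cur, nib);
            cur = next;
        }

        Some(cur)
    }

    // Plain lookup: both callbacks are no-ops.
    fn get(&self, key: &[u8]) -> Option<&[u8]> {
        let addr = self.get_with_callbacks(key, |_, _| {}, |_, _| {})?;
        self.nodes[addr].value.as_deref()
    }

    // Parent tracking: record every (parent, nibble) hop taken on the way down.
    fn get_with_parents(&self, key: &[u8]) -> (Option<Addr>, Vec<(Addr, u8)>) {
        let mut parents = Vec::new();
        let found =
            self.get_with_callbacks(key, |_, _| {}, |addr, nib| parents.push((addr, nib)));
        (found, parents)
    }

    // Proof-path collection: record every node visited, including the final one.
    fn proof_path(&self, key: &[u8]) -> Vec<Addr> {
        let mut path = Vec::new();
        let found = self.get_with_callbacks(key, |addr, _| path.push(addr), |_, _| {});
        if let Some(addr) = found {
            path.push(addr);
        }
        path
    }
}

fn main() {
    // Toy trie: node 0 (root) has one child at nibble 0x0f; node 1 holds a value.
    let mut root = Node { children: HashMap::new(), value: None };
    root.children.insert(0x0f, 1);
    let leaf = Node { children: HashMap::new(), value: Some(b"val".to_vec()) };
    let trie = Trie { nodes: vec![root, leaf] };

    assert_eq!(trie.get(&[0x0f]), Some(&b"val"[..]));
    assert_eq!(trie.proof_path(&[0x0f]), vec![0, 1]);
    assert_eq!(trie.get_with_parents(&[0x0f]).1, vec![(0, 0x0f)]);
}

Each wrapper only decides what to record; the walking logic lives in one place, which is what lets the proof-generation code in the diff drop its duplicated branch/leaf/extension handling.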
327 changes: 256 additions & 71 deletions firewood/src/merkle.rs
@@ -970,7 +970,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
node_ref: ObjRef<'a>,
key: K,
) -> Result<Option<ObjRef<'a>>, MerkleError> {
self.get_node_by_key_with_callback(node_ref, key, |_, _| {})
self.get_node_by_key_with_callbacks(node_ref, key, |_, _| {}, |_, _| {})
}

fn get_node_and_parents_by_key<'a, K: AsRef<[u8]>>(
@@ -979,9 +979,14 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
key: K,
) -> Result<(Option<ObjRef<'a>>, ParentRefs<'a>), MerkleError> {
let mut parents = Vec::new();
let node_ref = self.get_node_by_key_with_callback(node_ref, key, |node_ref, nib| {
parents.push((node_ref, nib));
})?;
let node_ref = self.get_node_by_key_with_callbacks(
node_ref,
key,
|_, _| {},
|node_ref, nib| {
parents.push((node_ref, nib));
},
)?;

Ok((node_ref, parents))
}
@@ -992,18 +997,24 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
key: K,
) -> Result<(Option<ObjRef<'a>>, ParentAddresses), MerkleError> {
let mut parents = Vec::new();
let node_ref = self.get_node_by_key_with_callback(node_ref, key, |node_ref, nib| {
parents.push((node_ref.into_ptr(), nib));
})?;
let node_ref = self.get_node_by_key_with_callbacks(
node_ref,
key,
|_, _| {},
|node_ref, nib| {
parents.push((node_ref.into_ptr(), nib));
},
)?;

Ok((node_ref, parents))
}

fn get_node_by_key_with_callback<'a, K: AsRef<[u8]>>(
fn get_node_by_key_with_callbacks<'a, K: AsRef<[u8]>>(
&'a self,
mut node_ref: ObjRef<'a>,
key: K,
mut loop_callback: impl FnMut(ObjRef<'a>, u8),
mut start_loop_callback: impl FnMut(DiskAddress, u8),
mut end_loop_callback: impl FnMut(ObjRef<'a>, u8),
) -> Result<Option<ObjRef<'a>>, MerkleError> {
let mut key_nibbles = Nibbles::<1>::new(key.as_ref()).into_iter();

@@ -1012,6 +1023,8 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
break;
};

start_loop_callback(node_ref.as_ptr(), nib);

let next_ptr = match &node_ref.inner {
NodeType::Branch(n) => match n.children[nib as usize] {
Some(c) => c,
@@ -1045,7 +1058,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
}
};

loop_callback(node_ref, nib);
end_loop_callback(node_ref, nib);

node_ref = self.get_node(next_ptr)?;
}
@@ -1090,78 +1103,28 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
where
K: AsRef<[u8]>,
{
let key_nibbles = Nibbles::<0>::new(key.as_ref());

let mut proofs = HashMap::new();
if root.is_null() {
return Ok(Proof(proofs));
}

// Skip the sentinel root
let root = self
.get_node(root)?
.inner
.as_branch()
.ok_or(MerkleError::NotBranchNode)?
.children[0];
let mut u_ref = match root {
Some(root) => self.get_node(root)?,
None => return Ok(Proof(proofs)),
};
let root_node = self.get_node(root)?;

let mut nskip = 0;
let mut nodes: Vec<DiskAddress> = Vec::new();
let mut nodes = Vec::new();

// TODO: use get_node_by_key (and write proper unit test)
for (i, nib) in key_nibbles.into_iter().enumerate() {
if nskip > 0 {
nskip -= 1;
continue;
}
nodes.push(u_ref.as_ptr());
let next_ptr: DiskAddress = match &u_ref.inner {
NodeType::Branch(n) => match n.children[nib as usize] {
Some(c) => c,
None => break,
},
NodeType::Leaf(_) => break,
NodeType::Extension(n) => {
// the key passed in must match the entire remainder of this
// extension node, otherwise we break out
let n_path = &n.path;
let remaining_path = key_nibbles.into_iter().skip(i);
if remaining_path.size_hint().0 < n_path.len() {
// all bytes aren't there
break;
}
if !remaining_path.take(n_path.len()).eq(n_path.iter().cloned()) {
// contents aren't the same
break;
}
nskip = n_path.len() - 1;
n.chd()
}
};
u_ref = self.get_node(next_ptr)?;
}
let node = self.get_node_by_key_with_callbacks(
root_node,
key,
|node, _| nodes.push(node),
|_, _| {},
)?;

match &u_ref.inner {
NodeType::Branch(n) => {
if n.value.as_ref().is_some() {
nodes.push(u_ref.as_ptr());
}
}
NodeType::Leaf(n) => {
if n.path.len() == 0 {
nodes.push(u_ref.as_ptr());
}
}
_ => (),
if let Some(node) = node {
nodes.push(node.as_ptr());
}

drop(u_ref);
// Get the hashes of the nodes.
for node in nodes {
for node in nodes.into_iter().skip(1) {
let node = self.get_node(node)?;
let encoded = <&[u8]>::clone(&node.get_encoded::<S>(self.store.as_ref()));
let hash: [u8; TRIE_HASH_LEN] = sha3::Keccak256::digest(encoded).into();
@@ -1864,6 +1827,16 @@ mod tests {
}
}

#[test]
fn get_empty_proof() {
let merkle = create_test_merkle();
let root = merkle.init_root().unwrap();

let proof = merkle.prove(b"any-key", root).unwrap();

assert!(proof.0.is_empty());
}

#[tokio::test]
async fn empty_range_proof() {
let merkle: Merkle<CompactSpace<Node, DynamicMem>> = create_test_merkle();
@@ -2001,4 +1974,216 @@ mod tests {
assert_eq!(fetched_val.as_deref(), val.into());
}
}

#[test]
fn overwrite_leaf() {
let key = vec![0x00];
let val = vec![1];
let overwrite = vec![2];

let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();

merkle.insert(&key, val.clone(), root).unwrap();

assert_eq!(
merkle.get(&key, root).unwrap().as_deref(),
Some(val.as_slice())
);

merkle.insert(&key, overwrite.clone(), root).unwrap();

assert_eq!(
merkle.get(&key, root).unwrap().as_deref(),
Some(overwrite.as_slice())
);
}

#[test]
fn new_leaf_is_a_child_of_the_old_leaf() {
let key = vec![0xff];
let val = vec![1];
let key_2 = vec![0xff, 0x00];
let val_2 = vec![2];

let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();

merkle.insert(&key, val.clone(), root).unwrap();
merkle.insert(&key_2, val_2.clone(), root).unwrap();

assert_eq!(
merkle.get(&key, root).unwrap().as_deref(),
Some(val.as_slice())
);

assert_eq!(
merkle.get(&key_2, root).unwrap().as_deref(),
Some(val_2.as_slice())
);
}

#[test]
fn old_leaf_is_a_child_of_the_new_leaf() {
let key = vec![0xff, 0x00];
let val = vec![1];
let key_2 = vec![0xff];
let val_2 = vec![2];

let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();

merkle.insert(&key, val.clone(), root).unwrap();
merkle.insert(&key_2, val_2.clone(), root).unwrap();

assert_eq!(
merkle.get(&key, root).unwrap().as_deref(),
Some(val.as_slice())
);

assert_eq!(
merkle.get(&key_2, root).unwrap().as_deref(),
Some(val_2.as_slice())
);
}

#[test]
fn new_leaf_is_sibling_of_old_leaf() {
let key = vec![0xff];
let val = vec![1];
let key_2 = vec![0xff, 0x00];
let val_2 = vec![2];
let key_3 = vec![0xff, 0x0f];
let val_3 = vec![3];

let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();

merkle.insert(&key, val.clone(), root).unwrap();
merkle.insert(&key_2, val_2.clone(), root).unwrap();
merkle.insert(&key_3, val_3.clone(), root).unwrap();

assert_eq!(
merkle.get(&key, root).unwrap().as_deref(),
Some(val.as_slice())
);

assert_eq!(
merkle.get(&key_2, root).unwrap().as_deref(),
Some(val_2.as_slice())
);

assert_eq!(
merkle.get(&key_3, root).unwrap().as_deref(),
Some(val_3.as_slice())
);
}

#[test]
fn old_branch_is_a_child_of_new_branch() {
let key = vec![0xff, 0xf0];
let val = vec![1];
let key_2 = vec![0xff, 0xf0, 0x00];
let val_2 = vec![2];
let key_3 = vec![0xff];
let val_3 = vec![3];

let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();

merkle.insert(&key, val.clone(), root).unwrap();
merkle.insert(&key_2, val_2.clone(), root).unwrap();
merkle.insert(&key_3, val_3.clone(), root).unwrap();

assert_eq!(
merkle.get(&key, root).unwrap().as_deref(),
Some(val.as_slice())
);

assert_eq!(
merkle.get(&key_2, root).unwrap().as_deref(),
Some(val_2.as_slice())
);

assert_eq!(
merkle.get(&key_3, root).unwrap().as_deref(),
Some(val_3.as_slice())
);
}

#[test]
fn overlapping_branch_insert() {
let key = vec![0xff];
let val = vec![1];
let key_2 = vec![0xff, 0x00];
let val_2 = vec![2];

let overwrite = vec![3];

let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();

merkle.insert(&key, val.clone(), root).unwrap();
merkle.insert(&key_2, val_2.clone(), root).unwrap();

assert_eq!(
merkle.get(&key, root).unwrap().as_deref(),
Some(val.as_slice())
);

assert_eq!(
merkle.get(&key_2, root).unwrap().as_deref(),
Some(val_2.as_slice())
);

merkle.insert(&key, overwrite.clone(), root).unwrap();

assert_eq!(
merkle.get(&key, root).unwrap().as_deref(),
Some(overwrite.as_slice())
);

assert_eq!(
merkle.get(&key_2, root).unwrap().as_deref(),
Some(val_2.as_slice())
);
}

#[test]
fn single_key_proof_with_one_node() {
let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();
let key = b"key";
let value = b"value";

merkle.insert(key, value.to_vec(), root).unwrap();
let root_hash = merkle.root_hash(root).unwrap();

let proof = merkle.prove(key, root).unwrap();

let verified = proof.verify(key, root_hash.0).unwrap();

assert_eq!(verified, Some(value.to_vec()));
}

#[test]
fn two_key_proof_without_shared_path() {
let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();

let key1 = &[0x00];
let key2 = &[0xff];

merkle.insert(key1, key1.to_vec(), root).unwrap();
merkle.insert(key2, key2.to_vec(), root).unwrap();

let root_hash = merkle.root_hash(root).unwrap();

let verified = {
let proof = merkle.prove(key1, root).unwrap();
proof.verify(key1, root_hash.0).unwrap()
};

assert_eq!(verified.as_deref(), Some(key1.as_slice()));
}
}