Pluggable encoding for nodes (#317)
xinifinity authored Dec 1, 2023
1 parent 95f37a2 commit dfb33e7
Showing 6 changed files with 337 additions and 47 deletions.
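
This commit makes the node wire format pluggable: Merkle<S> becomes Merkle<S, T>, where T names a codec chosen at compile time and is carried only as PhantomData<T>; the existing callers (DbRev and the tests) pin T = Bincode. The BinarySerde trait and the Bincode codec are defined in the node module, which is not expanded in this view. As a rough sketch of the shape implied by the calls `Bincode::serialize(&encoded)` and `T::deserialize(buf)` in the merkle.rs hunks below (the associated error type and the use of owned rather than borrowed deserialization are assumptions, not taken from the diff):

use serde::{de::DeserializeOwned, Serialize};

/// Hypothetical shape of the codec abstraction; the real `BinarySerde`
/// trait lives in the node module, which is not shown in this excerpt.
pub trait BinarySerde {
    type Error: std::fmt::Display;

    /// Serialize any serde-serializable value to bytes.
    fn serialize<V: Serialize>(value: &V) -> Result<Vec<u8>, Self::Error>;

    /// Deserialize a value from bytes.
    fn deserialize<V: DeserializeOwned>(bytes: &[u8]) -> Result<V, Self::Error>;
}

/// Bincode-backed codec, matching the `Bincode` type the diff re-exports
/// and plugs into `Merkle<S, Bincode>`.
pub struct Bincode;

impl BinarySerde for Bincode {
    type Error = bincode::Error;

    fn serialize<V: Serialize>(value: &V) -> Result<Vec<u8>, Self::Error> {
        bincode::serialize(value)
    }

    fn deserialize<V: DeserializeOwned>(bytes: &[u8]) -> Result<V, Self::Error> {
        bincode::deserialize(bytes)
    }
}

Because the codec exists only at the type level, swapping encodings adds no per-instance state; the diff below simply threads the extra parameter through MerkleKeyValueStream, RefMut, and the test helpers.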
8 changes: 4 additions & 4 deletions firewood/src/db.rs
@@ -8,7 +8,7 @@ pub use crate::{
};
use crate::{
file,
merkle::{Merkle, MerkleError, Node, Proof, ProofError, TrieHash, TRIE_HASH_LEN},
merkle::{Bincode, Merkle, MerkleError, Node, Proof, ProofError, TrieHash, TRIE_HASH_LEN},
storage::{
buffer::{DiskBuffer, DiskBufferRequester},
CachedSpace, MemStoreR, SpaceWrite, StoreConfig, StoreDelta, StoreRevMut, StoreRevShared,
@@ -275,7 +275,7 @@ impl<T: MemStoreR + 'static> Universe<Arc<T>> {
#[derive(Debug)]
pub struct DbRev<S> {
header: shale::Obj<DbHeader>,
merkle: Merkle<S>,
merkle: Merkle<S, Bincode>,
}

#[async_trait]
@@ -321,7 +321,7 @@ impl<S: ShaleStore<Node> + Send + Sync> DbRev<S> {
pub fn stream<K: KeyType>(
&self,
start_key: Option<K>,
) -> Result<merkle::MerkleKeyValueStream<'_, S>, api::Error> {
) -> Result<merkle::MerkleKeyValueStream<'_, S, Bincode>, api::Error> {
self.merkle
.get_iter(start_key, self.header.kv_root)
.map_err(|e| api::Error::InternalError(Box::new(e)))
@@ -333,7 +333,7 @@ impl<S: ShaleStore<Node> + Send + Sync> DbRev<S> {
Some(())
}

fn borrow_split(&mut self) -> (&mut shale::Obj<DbHeader>, &mut Merkle<S>) {
fn borrow_split(&mut self) -> (&mut shale::Obj<DbHeader>, &mut Merkle<S, Bincode>) {
(&mut self.header, &mut self.merkle)
}

132 changes: 109 additions & 23 deletions firewood/src/merkle.rs
@@ -6,16 +6,19 @@ use crate::v2::api;
use futures::{Stream, StreamExt, TryStreamExt};
use sha3::Digest;
use std::{
cmp::Ordering, collections::HashMap, future::ready, io::Write, iter::once, sync::OnceLock,
task::Poll,
cmp::Ordering, collections::HashMap, future::ready, io::Write, iter::once, marker::PhantomData,
sync::OnceLock, task::Poll,
};
use thiserror::Error;

mod node;
pub mod proof;
mod trie_hash;

pub use node::{BranchNode, Data, ExtNode, LeafNode, Node, NodeType, PartialPath};
pub use node::{
BinarySerde, Bincode, BranchNode, Data, EncodedNode, EncodedNodeType, ExtNode, LeafNode, Node,
NodeType, PartialPath,
};
pub use proof::{Proof, ProofError};
pub use trie_hash::{TrieHash, TRIE_HASH_LEN};

@@ -39,6 +39,8 @@ pub enum MerkleError {
UnsetInternal,
#[error("error updating nodes: {0}")]
WriteError(#[from] ObjWriteError),
#[error("merkle serde error: {0}")]
BinarySerdeError(String),
}

macro_rules! write_node {
@@ -55,11 +60,12 @@ macro_rules! write_node {
}

#[derive(Debug)]
pub struct Merkle<S> {
pub struct Merkle<S, T> {
store: Box<S>,
phantom: PhantomData<T>,
}

impl<S: ShaleStore<Node>> Merkle<S> {
impl<S: ShaleStore<Node>, T> Merkle<S, T> {
pub fn get_node(&self, ptr: DiskAddress) -> Result<ObjRef, MerkleError> {
self.store.get_item(ptr).map_err(Into::into)
}
@@ -73,11 +79,72 @@ impl<S: ShaleStore<Node>> Merkle<S> {
}
}

impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
impl<'de, S, T> Merkle<S, T>
where
S: ShaleStore<Node> + Send + Sync,
T: BinarySerde,
EncodedNode<T>: serde::Serialize + serde::Deserialize<'de>,
{
pub fn new(store: Box<S>) -> Self {
Self { store }
Self {
store,
phantom: PhantomData,
}
}

// TODO: use `encode` / `decode` instead of `node.encode` / `node.decode` after extension node removal.
#[allow(dead_code)]
fn encode(&self, node: &ObjRef) -> Result<Vec<u8>, MerkleError> {
let encoded = match node.inner() {
NodeType::Leaf(n) => EncodedNode::new(EncodedNodeType::Leaf(n.clone())),
NodeType::Branch(n) => {
// pair up DiskAddresses with encoded children and pick the right one
let encoded_children = n.chd().iter().zip(n.children_encoded.iter());
let children = encoded_children
.map(|(child_addr, encoded_child)| {
child_addr
// if there's a child disk address here, get the encoded bytes
.map(|addr| self.get_node(addr).and_then(|node| self.encode(&node)))
// or look for the pre-fetched bytes
.or_else(|| encoded_child.as_ref().map(|child| Ok(child.to_vec())))
.transpose()
})
.collect::<Result<Vec<Option<Vec<u8>>>, MerkleError>>()?
.try_into()
.expect("MAX_CHILDREN will always be yielded");

EncodedNode::new(EncodedNodeType::Branch {
children,
value: n.value.clone(),
})
}

NodeType::Extension(_) => todo!(),
};

Bincode::serialize(&encoded).map_err(|e| MerkleError::BinarySerdeError(e.to_string()))
}

#[allow(dead_code)]
fn decode(&self, buf: &'de [u8]) -> Result<NodeType, MerkleError> {
let encoded: EncodedNode<T> =
T::deserialize(buf).map_err(|e| MerkleError::BinarySerdeError(e.to_string()))?;

match encoded.node {
EncodedNodeType::Leaf(leaf) => Ok(NodeType::Leaf(leaf)),
EncodedNodeType::Branch { children, value } => {
let path = Vec::new().into();
let value = value.map(|v| v.0);
Ok(NodeType::Branch(
BranchNode::new(path, [None; BranchNode::MAX_CHILDREN], value, *children)
.into(),
))
}
}
}
}

impl<S: ShaleStore<Node> + Send + Sync, T> Merkle<S, T> {
pub fn init_root(&self) -> Result<DiskAddress, MerkleError> {
self.store
.put_item(
@@ -1072,7 +1139,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
&mut self,
key: K,
root: DiskAddress,
) -> Result<Option<RefMut<S>>, MerkleError> {
) -> Result<Option<RefMut<S, T>>, MerkleError> {
if root.is_null() {
return Ok(None);
}
@@ -1151,7 +1218,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
&self,
key: Option<K>,
root: DiskAddress,
) -> Result<MerkleKeyValueStream<'_, S>, MerkleError> {
) -> Result<MerkleKeyValueStream<'_, S, T>, MerkleError> {
Ok(MerkleKeyValueStream {
key_state: IteratorState::new(key),
merkle_root: root,
@@ -1269,13 +1336,15 @@ impl<'a> Default for IteratorState<'a> {
/// This iterator is not fused. If you read past the None value, you start
/// over at the beginning. If you need a fused iterator, consider using
/// std::iter::fuse
pub struct MerkleKeyValueStream<'a, S> {
pub struct MerkleKeyValueStream<'a, S, T> {
key_state: IteratorState<'a>,
merkle_root: DiskAddress,
merkle: &'a Merkle<S>,
merkle: &'a Merkle<S, T>,
}

impl<'a, S: shale::ShaleStore<node::Node> + Send + Sync> Stream for MerkleKeyValueStream<'a, S> {
impl<'a, S: shale::ShaleStore<node::Node> + Send + Sync, T> Stream
for MerkleKeyValueStream<'a, S, T>
{
type Item = Result<(Vec<u8>, Vec<u8>), api::Error>;

fn poll_next(
@@ -1531,10 +1600,10 @@ fn set_parent(new_chd: DiskAddress, parents: &mut [(ObjRef, u8)]) {

pub struct Ref<'a>(ObjRef<'a>);

pub struct RefMut<'a, S> {
pub struct RefMut<'a, S, T> {
ptr: DiskAddress,
parents: ParentAddresses,
merkle: &'a mut Merkle<S>,
merkle: &'a mut Merkle<S, T>,
}

impl<'a> std::ops::Deref for Ref<'a> {
@@ -1548,8 +1617,8 @@ impl<'a> std::ops::Deref for Ref<'a> {
}
}

impl<'a, S: ShaleStore<Node> + Send + Sync> RefMut<'a, S> {
fn new(ptr: DiskAddress, parents: ParentAddresses, merkle: &'a mut Merkle<S>) -> Self {
impl<'a, S: ShaleStore<Node> + Send + Sync, T> RefMut<'a, S, T> {
fn new(ptr: DiskAddress, parents: ParentAddresses, merkle: &'a mut Merkle<S, T>) -> Self {
Self {
ptr,
parents,
Expand Down Expand Up @@ -1621,7 +1690,7 @@ mod tests {
assert_eq!(n, nibbles);
}

fn create_test_merkle() -> Merkle<CompactSpace<Node, DynamicMem>> {
fn create_test_merkle() -> Merkle<CompactSpace<Node, DynamicMem>, Bincode> {
const RESERVED: usize = 0x1000;

let mut dm = shale::cached::DynamicMem::new(0x10000, 0);
@@ -1669,9 +1738,11 @@ mod tests {
})
}

#[test_case(leaf(Vec::new(), Vec::new()) ; "empty leaf encoding")]
#[test_case(leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf encoding")]
#[test_case(branch(b"value".to_vec(), vec![1, 2, 3].into()) ; "branch with value")]
#[test_case(branch(b"value".to_vec(), None); "branch without value")]
#[test_case(branch(b"value".to_vec(), vec![1, 2, 3].into()) ; "branch with chd")]
#[test_case(branch(b"value".to_vec(), None); "branch without chd")]
#[test_case(branch(Vec::new(), None); "branch without value and chd")]
#[test_case(extension(vec![1, 2, 3], DiskAddress::null(), vec![4, 5].into()) ; "extension without child address")]
fn encode_(node: Node) {
let merkle = create_test_merkle();
@@ -1684,6 +1755,21 @@ mod tests {
assert_eq!(encoded, new_node_encoded);
}

#[test_case(leaf(Vec::new(), Vec::new()) ; "empty leaf encoding")]
#[test_case(leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf encoding")]
#[test_case(branch(b"value".to_vec(), vec![1, 2, 3].into()) ; "branch with chd")]
#[test_case(branch(b"value".to_vec(), None); "branch without chd")]
#[test_case(branch(Vec::new(), None); "branch without value and chd")]
fn node_encode_decode_(node: Node) {
let merkle = create_test_merkle();
let node_ref = merkle.put_node(node.clone()).unwrap();
let encoded = merkle.encode(&node_ref).unwrap();
let new_node = Node::from(merkle.decode(encoded.as_ref()).unwrap());
let new_node_hash = new_node.get_root_hash(merkle.store.as_ref());
let expected_hash = node.get_root_hash(merkle.store.as_ref());
assert_eq!(new_node_hash, expected_hash);
}

#[test]
fn insert_and_retrieve_one() {
let key = b"hello";
Expand Down Expand Up @@ -1793,7 +1879,7 @@ mod tests {

#[test]
fn remove_many() {
let mut merkle: Merkle<CompactSpace<Node, DynamicMem>> = create_test_merkle();
let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();

// insert values
@@ -1834,7 +1920,7 @@ mod tests {

#[tokio::test]
async fn empty_range_proof() {
let merkle: Merkle<CompactSpace<Node, DynamicMem>> = create_test_merkle();
let merkle = create_test_merkle();
let root = merkle.init_root().unwrap();

assert!(merkle
@@ -1846,7 +1932,7 @@ mod tests {

#[tokio::test]
async fn full_range_proof() {
let mut merkle: Merkle<CompactSpace<Node, DynamicMem>> = create_test_merkle();
let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();
// insert values
for key_val in u8::MIN..=u8::MAX {
@@ -1874,7 +1960,7 @@ mod tests {
async fn single_value_range_proof() {
const RANDOM_KEY: u8 = 42;

let mut merkle: Merkle<CompactSpace<Node, DynamicMem>> = create_test_merkle();
let mut merkle = create_test_merkle();
let root = merkle.init_root().unwrap();
// insert values
for key_val in u8::MIN..=u8::MAX {
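The tests above pin the codec through create_test_merkle(), which returns Merkle<CompactSpace<Node, DynamicMem>, Bincode>. The point of the extra type parameter is that another wire format only needs its own BinarySerde implementation. A purely illustrative example, building on the trait shape sketched after the file summary above (the PlainJson codec and its serde_json backing are hypothetical and not part of this commit):

use serde::{de::DeserializeOwned, Serialize};

/// Hypothetical alternative codec; assumes the `BinarySerde` shape sketched
/// earlier (static serialize/deserialize associated functions).
pub struct PlainJson;

impl BinarySerde for PlainJson {
    type Error = serde_json::Error;

    fn serialize<V: Serialize>(value: &V) -> Result<Vec<u8>, Self::Error> {
        serde_json::to_vec(value)
    }

    fn deserialize<V: DeserializeOwned>(bytes: &[u8]) -> Result<V, Self::Error> {
        serde_json::from_slice(bytes)
    }
}

// Choosing the codec is then a type-level decision at construction time, e.g.:
// let merkle: Merkle<CompactSpace<Node, DynamicMem>, PlainJson> = Merkle::new(store);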
(The remaining 4 changed files did not load in this view.)
