diff --git a/firewood/benches/hashops.rs b/firewood/benches/hashops.rs index 93dc9b3d4..a03429117 100644 --- a/firewood/benches/hashops.rs +++ b/firewood/benches/hashops.rs @@ -6,7 +6,7 @@ use criterion::{criterion_group, criterion_main, profiler::Profiler, BatchSize, Criterion}; use firewood::{ db::{BatchOp, DbConfig}, - merkle::{Merkle, TrieHash, TRIE_HASH_LEN}, + merkle::{MerkleWithEncoder, TrieHash, TRIE_HASH_LEN}, shale::{ cached::PlainMem, compact::{CompactHeader, CompactSpace}, @@ -104,7 +104,7 @@ fn bench_merkle(criterion: &mut Criterion) { ) .unwrap(); - let merkle = Merkle::new(Box::new(store)); + let merkle = MerkleWithEncoder::new(Box::new(store)); let root = merkle.init_root().unwrap(); let keys: Vec> = repeat_with(|| { diff --git a/firewood/src/merkle.rs b/firewood/src/merkle.rs index 7e64c3f02..f567c193f 100644 --- a/firewood/src/merkle.rs +++ b/firewood/src/merkle.rs @@ -1,7 +1,10 @@ // Copyright (C) 2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE.md for licensing terms. use crate::nibbles::Nibbles; -use crate::shale::{self, disk_address::DiskAddress, ObjWriteError, ShaleError, ShaleStore}; +use crate::shale::{ + self, cached::PlainMem, compact::CompactSpace, disk_address::DiskAddress, ObjWriteError, + ShaleError, ShaleStore, +}; use crate::v2::api; use futures::{StreamExt, TryStreamExt}; use sha3::Digest; @@ -28,6 +31,7 @@ pub use trie_hash::{TrieHash, TRIE_HASH_LEN}; type ObjRef<'a> = shale::ObjRef<'a, Node>; type ParentRefs<'a> = Vec<(ObjRef<'a>, u8)>; type ParentAddresses = Vec<(DiskAddress, u8)>; +pub type MerkleWithEncoder = Merkle, Bincode>; #[derive(Debug, Error)] pub enum MerkleError { @@ -1399,6 +1403,8 @@ pub fn from_nibbles(nibbles: &[u8]) -> impl Iterator + '_ { #[cfg(test)] mod tests { + use crate::merkle::node::PlainCodec; + use super::*; use node::tests::{extension, leaf}; use shale::{cached::DynamicMem, compact::CompactSpace, CachedStore}; @@ -1412,7 +1418,11 @@ mod tests { assert_eq!(n, nibbles); } - pub(super) fn create_test_merkle() -> Merkle, Bincode> { + fn create_generic_test_merkle<'de, T>() -> Merkle, T> + where + T: BinarySerde, + EncodedNode: serde::Serialize + serde::Deserialize<'de>, + { const RESERVED: usize = 0x1000; let mut dm = shale::cached::DynamicMem::new(0x10000, 0); @@ -1443,9 +1453,13 @@ mod tests { Merkle::new(store) } - fn branch(value: Vec, encoded_child: Option>) -> Node { + pub(super) fn create_test_merkle() -> Merkle, Bincode> { + create_generic_test_merkle::() + } + + fn branch(value: Option>, encoded_child: Option>) -> Node { let children = Default::default(); - let value = Some(Data(value)); + let value = value.map(Data); let mut children_encoded = <[Option>; BranchNode::MAX_CHILDREN]>::default(); if let Some(child) = encoded_child { @@ -1462,9 +1476,9 @@ mod tests { #[test_case(leaf(Vec::new(), Vec::new()) ; "empty leaf encoding")] #[test_case(leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf encoding")] - #[test_case(branch(b"value".to_vec(), vec![1, 2, 3].into()) ; "branch with chd")] - #[test_case(branch(b"value".to_vec(), None); "branch without chd")] - #[test_case(branch(Vec::new(), None); "branch without value and chd")] + #[test_case(branch(Some(b"value".to_vec()), vec![1, 2, 3].into()) ; "branch with chd")] + #[test_case(branch(Some(b"value".to_vec()), None); "branch without chd")] + #[test_case(branch(None, None); "branch without value and chd")] #[test_case(extension(vec![1, 2, 3], DiskAddress::null(), vec![4, 5].into()) ; "extension without child address")] fn encode_(node: Node) { let merkle = create_test_merkle(); @@ -1479,17 +1493,32 @@ mod tests { #[test_case(leaf(Vec::new(), Vec::new()) ; "empty leaf encoding")] #[test_case(leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf encoding")] - #[test_case(branch(b"value".to_vec(), vec![1, 2, 3].into()) ; "branch with chd")] - #[test_case(branch(b"value".to_vec(), None); "branch without chd")] - #[test_case(branch(Vec::new(), None); "branch without value and chd")] - fn node_encode_decode_(node: Node) { + #[test_case(branch(Some(b"value".to_vec()), vec![1, 2, 3].into()) ; "branch with chd")] + #[test_case(branch(Some(b"value".to_vec()), None); "branch without chd")] + #[test_case(branch(None, None); "branch without value and chd")] + fn node_encode_decode(node: Node) { let merkle = create_test_merkle(); let node_ref = merkle.put_node(node.clone()).unwrap(); + let encoded = merkle.encode(&node_ref).unwrap(); let new_node = Node::from(merkle.decode(encoded.as_ref()).unwrap()); - let new_node_hash = new_node.get_root_hash(merkle.store.as_ref()); - let expected_hash = node.get_root_hash(merkle.store.as_ref()); - assert_eq!(new_node_hash, expected_hash); + + assert_eq!(node, new_node); + } + + #[test_case(leaf(Vec::new(), Vec::new()) ; "empty leaf encoding")] + #[test_case(leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf encoding")] + #[test_case(branch(Some(b"value".to_vec()), vec![1, 2, 3].into()) ; "branch with chd")] + #[test_case(branch(Some(b"value".to_vec()), Some(Vec::new())); "branch with empty chd")] + #[test_case(branch(Some(Vec::new()), vec![1, 2, 3].into()); "branch with empty value")] + fn node_encode_decode_plain(node: Node) { + let merkle = create_generic_test_merkle::(); + let node_ref = merkle.put_node(node.clone()).unwrap(); + + let encoded = merkle.encode(&node_ref).unwrap(); + let new_node = Node::from(merkle.decode(encoded.as_ref()).unwrap()); + + assert_eq!(node, new_node); } #[test] diff --git a/firewood/src/merkle/node.rs b/firewood/src/merkle/node.rs index 5e6e9bc2d..74ac9336d 100644 --- a/firewood/src/merkle/node.rs +++ b/firewood/src/merkle/node.rs @@ -8,7 +8,11 @@ use crate::{ use bincode::{Error, Options}; use bitflags::bitflags; use enum_as_inner::EnumAsInner; -use serde::{de::DeserializeOwned, ser::SerializeSeq, Deserialize, Serialize}; +use serde::{ + de::DeserializeOwned, + ser::{SerializeSeq, SerializeTuple}, + Deserialize, Serialize, +}; use sha3::{Digest, Keccak256}; use std::{ fmt::Debug, @@ -482,6 +486,7 @@ impl EncodedNode { } } } + pub enum EncodedNodeType { Leaf(LeafNode), Branch { @@ -490,6 +495,91 @@ pub enum EncodedNodeType { }, } +// TODO: probably can merge with `EncodedNodeType`. +#[derive(Debug, Deserialize)] +struct EncodedBranchNode { + chd: Vec<(u64, Vec)>, + data: Option>, + path: Vec, +} + +// Note that the serializer passed in should always be the same type as T in EncodedNode. +impl Serialize for EncodedNode { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let n = match &self.node { + EncodedNodeType::Leaf(n) => { + let data = Some(n.data.to_vec()); + let chd: Vec<(u64, Vec)> = Default::default(); + let path = from_nibbles(&n.path.encode(true)).collect(); + EncodedBranchNode { chd, data, path } + } + EncodedNodeType::Branch { children, value } => { + let chd: Vec<(u64, Vec)> = children + .iter() + .enumerate() + .filter_map(|(i, c)| c.as_ref().map(|c| (i as u64, c))) + .map(|(i, c)| { + if c.len() >= TRIE_HASH_LEN { + (i, Keccak256::digest(c).to_vec()) + } else { + (i, c.to_vec()) + } + }) + .collect(); + + let data = value.as_ref().map(|v| v.0.to_vec()); + EncodedBranchNode { + chd, + data, + path: Vec::new(), + } + } + }; + + let mut s = serializer.serialize_tuple(3)?; + s.serialize_element(&n.chd)?; + s.serialize_element(&n.data)?; + s.serialize_element(&n.path)?; + s.end() + } +} + +impl<'de> Deserialize<'de> for EncodedNode { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let node: EncodedBranchNode = Deserialize::deserialize(deserializer)?; + if node.chd.is_empty() { + let data = if let Some(d) = node.data { + Data(d) + } else { + Data(Vec::new()) + }; + + let path = PartialPath::from_nibbles(Nibbles::<0>::new(&node.path).into_iter()).0; + let node = EncodedNodeType::Leaf(LeafNode { path, data }); + Ok(Self::new(node)) + } else { + let mut children: [Option>; BranchNode::MAX_CHILDREN] = Default::default(); + let value = node.data.map(Data); + + for (i, chd) in node.chd { + children[i as usize] = Some(chd); + } + + let node = EncodedNodeType::Branch { + children: children.into(), + value, + }; + Ok(Self::new(node)) + } + } +} + // Note that the serializer passed in should always be the same type as T in EncodedNode. impl Serialize for EncodedNode { fn serialize(&self, serializer: S) -> Result { @@ -566,10 +656,7 @@ impl<'de> Deserialize<'de> for EncodedNode { path, data: Data(data), }); - Ok(Self { - node, - phantom: PhantomData, - }) + Ok(Self::new(node)) } BranchNode::MSIZE => { let mut children: [Option>; BranchNode::MAX_CHILDREN] = Default::default(); @@ -601,10 +688,7 @@ impl<'de> Deserialize<'de> for EncodedNode { children: children.into(), value, }; - Ok(Self { - node, - phantom: PhantomData, - }) + Ok(Self::new(node)) } size => Err(D::Error::custom(format!("invalid size: {size}"))), } @@ -666,6 +750,37 @@ impl BinarySerde for Bincode { } } +pub struct PlainCodec(pub bincode::DefaultOptions); + +impl Debug for PlainCodec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + write!(f, "PlainCodec") + } +} + +impl BinarySerde for PlainCodec { + type SerializeError = bincode::Error; + type DeserializeError = Self::SerializeError; + + fn new() -> Self { + Self(bincode::DefaultOptions::new()) + } + + fn serialize_impl(&self, t: &T) -> Result, Self::SerializeError> { + // Serializes the object directly into a Writer without include the length. + let mut writer = Vec::new(); + self.0.serialize_into(&mut writer, t)?; + Ok(writer) + } + + fn deserialize_impl<'de, T: Deserialize<'de>>( + &self, + bytes: &'de [u8], + ) -> Result { + self.0.deserialize(bytes) + } +} + #[cfg(test)] pub(super) mod tests { use std::array::from_fn;