Skip to content

Commit

Permalink
Delegate branch-storage logic to the branch mod (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardpringle authored Nov 29, 2023
1 parent 80a09f6 commit 4a4f4ba
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 80 deletions.
12 changes: 6 additions & 6 deletions firewood/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use thiserror::Error;
mod node;
mod trie_hash;

pub use node::{BranchNode, Data, ExtNode, LeafNode, Node, NodeType, PartialPath, MAX_CHILDREN};
pub use node::{BranchNode, Data, ExtNode, LeafNode, Node, NodeType, PartialPath};
pub use trie_hash::{TrieHash, TRIE_HASH_LEN};

type ObjRef<'a> = shale::ObjRef<'a, Node>;
Expand Down Expand Up @@ -81,7 +81,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
.put_item(
Node::from_branch(BranchNode {
// path: vec![].into(),
children: [None; MAX_CHILDREN],
children: [None; BranchNode::MAX_CHILDREN],
value: None,
children_encoded: Default::default(),
}),
Expand Down Expand Up @@ -199,7 +199,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
));
let leaf_address = self.put_node(new_node)?.as_ptr();

let mut chd = [None; MAX_CHILDREN];
let mut chd = [None; BranchNode::MAX_CHILDREN];

let last_matching_nibble = matching_path[idx];
chd[last_matching_nibble as usize] = Some(leaf_address);
Expand Down Expand Up @@ -340,7 +340,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
};

// [parent] (-> [ExtNode]) -> [branch with v] -> [Leaf]
let mut children = [None; MAX_CHILDREN];
let mut children = [None; BranchNode::MAX_CHILDREN];

children[idx] = leaf_address.into();

Expand Down Expand Up @@ -561,7 +561,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
};

if let Some((idx, more, ext, val)) = info {
let mut chd = [None; MAX_CHILDREN];
let mut chd = [None; BranchNode::MAX_CHILDREN];

let c_ptr = if more {
u_ptr
Expand Down Expand Up @@ -1695,7 +1695,7 @@ mod tests {
fn branch(value: Vec<u8>, encoded_child: Option<Vec<u8>>) -> Node {
let children = Default::default();
let value = Some(Data(value));
let mut children_encoded = <[Option<Vec<u8>>; MAX_CHILDREN]>::default();
let mut children_encoded = <[Option<Vec<u8>>; BranchNode::MAX_CHILDREN]>::default();

if let Some(child) = encoded_child {
children_encoded[0] = Some(child);
Expand Down
74 changes: 18 additions & 56 deletions firewood/src/merkle/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mod extension;
mod leaf;
mod partial_path;

pub use branch::{BranchNode, MAX_CHILDREN, SIZE as BRANCH_NODE_SIZE};
pub use branch::BranchNode;
pub use extension::ExtNode;
pub use leaf::{LeafNode, SIZE as LEAF_NODE_SIZE};
pub use partial_path::PartialPath;
Expand Down Expand Up @@ -114,7 +114,7 @@ impl NodeType {
}
}
// TODO: add path
BRANCH_NODE_SIZE => Ok(NodeType::Branch(BranchNode::decode(buf)?.into())),
BranchNode::MSIZE => Ok(NodeType::Branch(BranchNode::decode(buf)?.into())),
size => Err(Box::new(bincode::ErrorKind::Custom(format!(
"invalid size: {size}"
)))),
Expand Down Expand Up @@ -216,7 +216,7 @@ impl Node {
inner: NodeType::Branch(
BranchNode {
// path: vec![].into(),
children: [Some(DiskAddress::null()); MAX_CHILDREN],
children: [Some(DiskAddress::null()); BranchNode::MAX_CHILDREN],
value: Some(Data(Vec::new())),
children_encoded: Default::default(),
}
Expand Down Expand Up @@ -370,7 +370,7 @@ impl Storable for Node {
NodeTypeId::Branch => {
// TODO: add path
// TODO: figure out what this size is?
let branch_header_size = MAX_CHILDREN as u64 * 8 + 4;
let branch_header_size = BranchNode::MAX_CHILDREN as u64 * 8 + 4;
let node_raw = mem.get_view(addr + Meta::SIZE, branch_header_size).ok_or(
ShaleError::InvalidCacheView {
offset: addr + Meta::SIZE,
Expand All @@ -379,7 +379,7 @@ impl Storable for Node {
)?;

let mut cur = Cursor::new(node_raw.as_deref());
let mut chd = [None; MAX_CHILDREN];
let mut chd = [None; BranchNode::MAX_CHILDREN];
let mut buff = [0; 8];

for chd in chd.iter_mut() {
Expand All @@ -392,12 +392,13 @@ impl Storable for Node {

cur.read_exact(&mut buff[..4])?;

let raw_len =
u32::from_le_bytes(buff[..4].try_into().expect("invalid slice")) as u64;
let raw_len = u32::from_le_bytes(buff[..4].try_into().expect("invalid slice"));

let value = if raw_len == u32::MAX as u64 {
let value = if raw_len == u32::MAX {
None
} else {
let raw_len = raw_len as u64;

Some(Data(
mem.get_view(addr + Meta::SIZE + branch_header_size as usize, raw_len)
.ok_or(ShaleError::InvalidCacheView {
Expand All @@ -408,9 +409,10 @@ impl Storable for Node {
))
};

let mut chd_encoded: [Option<Vec<u8>>; MAX_CHILDREN] = Default::default();
let mut chd_encoded: [Option<Vec<u8>>; BranchNode::MAX_CHILDREN] =
Default::default();

let offset = if raw_len == u32::MAX as u64 {
let offset = if raw_len == u32::MAX {
addr + Meta::SIZE + branch_header_size as usize
} else {
addr + Meta::SIZE + branch_header_size as usize + raw_len as usize
Expand Down Expand Up @@ -598,20 +600,7 @@ impl Storable for Node {
+ match &self.inner {
NodeType::Branch(n) => {
// TODO: add path
let mut encoded_len = 0;
for emcoded in n.children_encoded.iter() {
encoded_len += match emcoded {
Some(v) => 1 + v.len() as u64,
None => 1,
}
}
MAX_CHILDREN as u64 * 8
+ 4
+ match &n.value {
Some(val) => val.len() as u64,
None => 0,
}
+ encoded_len
n.serialized_len()
}
NodeType::Extension(n) => {
1 + 8
Expand Down Expand Up @@ -654,36 +643,9 @@ impl Storable for Node {
// TODO: add path
cur.write_all(&[type_id::NodeTypeId::Branch as u8]).unwrap();

for c in n.children.iter() {
cur.write_all(&match c {
Some(p) => p.to_le_bytes(),
None => 0u64.to_le_bytes(),
})?;
}
let pos = cur.position() as usize;

match &n.value {
Some(val) => {
cur.write_all(&(val.len() as u32).to_le_bytes())?;
cur.write_all(val)?
}
None => {
cur.write_all(&u32::MAX.to_le_bytes())?;
}
}

// Since child encoding will only be unset after initialization (only used for range proof),
// it is fine to encode its value adjacent to other fields. Same for extension node.
for encoded in n.children_encoded.iter() {
match encoded {
Some(v) => {
cur.write_all(&[v.len() as u8])?;
cur.write_all(v)?
}
None => cur.write_all(&0u8.to_le_bytes())?,
}
}

Ok(())
n.serialize(&mut cur.get_mut()[pos..])
}

NodeType::Extension(n) => {
Expand Down Expand Up @@ -734,8 +696,8 @@ pub(super) mod tests {
value: Option<Vec<u8>>,
repeated_encoded_child: Option<Vec<u8>>,
) -> Node {
let children: [Option<DiskAddress>; MAX_CHILDREN] = from_fn(|i| {
if i < MAX_CHILDREN / 2 {
let children: [Option<DiskAddress>; BranchNode::MAX_CHILDREN] = from_fn(|i| {
if i < BranchNode::MAX_CHILDREN / 2 {
DiskAddress::from(repeated_disk_address).into()
} else {
None
Expand All @@ -745,7 +707,7 @@ pub(super) mod tests {
let children_encoded = repeated_encoded_child
.map(|child| {
from_fn(|i| {
if i < MAX_CHILDREN / 2 {
if i < BranchNode::MAX_CHILDREN / 2 {
child.clone().into()
} else {
None
Expand Down
94 changes: 80 additions & 14 deletions firewood/src/merkle/node/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,21 @@
use super::{Data, Encoded, Node};
use crate::{
merkle::{PartialPath, TRIE_HASH_LEN},
shale::DiskAddress,
shale::ShaleStore,
shale::{DiskAddress, Storable},
};
use bincode::{Error, Options};
use std::fmt::{Debug, Error as FmtError, Formatter};
use std::{
fmt::{Debug, Error as FmtError, Formatter},
io::{Cursor, Write},
mem::size_of,
ops::Deref,
};

pub type DataLen = u32;
pub type EncodedChildLen = u8;

pub const MAX_CHILDREN: usize = 16;
pub const SIZE: usize = MAX_CHILDREN + 1;
const MAX_CHILDREN: usize = 16;

#[derive(PartialEq, Eq, Clone)]
pub struct BranchNode {
Expand Down Expand Up @@ -50,11 +57,14 @@ impl Debug for BranchNode {
}

impl BranchNode {
pub const MAX_CHILDREN: usize = MAX_CHILDREN;
pub const MSIZE: usize = Self::MAX_CHILDREN + 1;

pub fn new(
_path: PartialPath,
chd: [Option<DiskAddress>; MAX_CHILDREN],
chd: [Option<DiskAddress>; Self::MAX_CHILDREN],
value: Option<Vec<u8>>,
chd_encoded: [Option<Vec<u8>>; MAX_CHILDREN],
chd_encoded: [Option<Vec<u8>>; Self::MAX_CHILDREN],
) -> Self {
BranchNode {
// path,
Expand All @@ -68,19 +78,19 @@ impl BranchNode {
&self.value
}

pub fn chd(&self) -> &[Option<DiskAddress>; MAX_CHILDREN] {
pub fn chd(&self) -> &[Option<DiskAddress>; Self::MAX_CHILDREN] {
&self.children
}

pub fn chd_mut(&mut self) -> &mut [Option<DiskAddress>; MAX_CHILDREN] {
pub fn chd_mut(&mut self) -> &mut [Option<DiskAddress>; Self::MAX_CHILDREN] {
&mut self.children
}

pub fn chd_encode(&self) -> &[Option<Vec<u8>>; MAX_CHILDREN] {
pub fn chd_encode(&self) -> &[Option<Vec<u8>>; Self::MAX_CHILDREN] {
&self.children_encoded
}

pub fn chd_encoded_mut(&mut self) -> &mut [Option<Vec<u8>>; MAX_CHILDREN] {
pub fn chd_encoded_mut(&mut self) -> &mut [Option<Vec<u8>>; Self::MAX_CHILDREN] {
&mut self.children_encoded
}

Expand Down Expand Up @@ -109,7 +119,7 @@ impl BranchNode {
let value = Some(data).filter(|data| !data.is_empty());

// encode all children.
let mut chd_encoded: [Option<Vec<u8>>; MAX_CHILDREN] = Default::default();
let mut chd_encoded: [Option<Vec<u8>>; Self::MAX_CHILDREN] = Default::default();

// we popped the last element, so there should only be NBRANCH items left
for (i, chd) in items.into_iter().enumerate() {
Expand All @@ -122,15 +132,15 @@ impl BranchNode {

Ok(BranchNode::new(
path,
[None; MAX_CHILDREN],
[None; Self::MAX_CHILDREN],
value,
chd_encoded,
))
}

pub(super) fn encode<S: ShaleStore<Node>>(&self, store: &S) -> Vec<u8> {
// TODO: add path to encoded node
let mut list = <[Encoded<Vec<u8>>; MAX_CHILDREN + 1]>::default();
let mut list = <[Encoded<Vec<u8>>; Self::MAX_CHILDREN + 1]>::default();

for (i, c) in self.children.iter().enumerate() {
match c {
Expand Down Expand Up @@ -170,7 +180,7 @@ impl BranchNode {
}

if let Some(Data(val)) = &self.value {
list[MAX_CHILDREN] =
list[Self::MAX_CHILDREN] =
Encoded::Data(bincode::DefaultOptions::new().serialize(val).unwrap());
}

Expand All @@ -179,3 +189,59 @@ impl BranchNode {
.unwrap()
}
}

/// Byte-level storage format of a branch node's payload.
///
/// Layout written by [`Storable::serialize`], in order:
/// 1. `MAX_CHILDREN` child addresses, each as little-endian bytes
///    (all-zero bytes when the child is absent);
/// 2. the value length as a little-endian `DataLen`, followed by the value
///    bytes; `DataLen::MAX` is the marker for "no value";
/// 3. for every child slot, a little-endian `EncodedChildLen` length followed
///    by the encoded child bytes; a length of zero is written for an absent
///    encoded child.
impl Storable for BranchNode {
    /// Returns the number of bytes that `serialize` writes for this node:
    /// the fixed-size address table plus the length-prefixed value and the
    /// length-prefixed encoded children.
    fn serialized_len(&self) -> u64 {
        let children_len = Self::MAX_CHILDREN as u64 * DiskAddress::MSIZE;
        let data_len = optional_data_len::<DataLen, _>(self.value.as_deref());
        let children_encoded_len = self
            .children_encoded
            .iter()
            .map(|child| optional_data_len::<EncodedChildLen, _>(child.as_ref()))
            .sum::<u64>();

        children_len + data_len + children_encoded_len
    }

    /// Writes this node into `to` using the layout described on the impl.
    ///
    /// # Errors
    ///
    /// Returns an error when
    /// - `to` is too small to hold the serialized node,
    /// - the value is too long for a `DataLen` prefix, or its length equals
    ///   `DataLen::MAX` (which is reserved to mean "no value"), or
    /// - an encoded child is too long for an `EncodedChildLen` prefix.
    ///
    /// The length checks replace silent `as` truncation, which would have
    /// written a length prefix that disagrees with the bytes that follow it.
    fn serialize(&self, to: &mut [u8]) -> Result<(), crate::shale::ShaleError> {
        const EMPTY: &[u8] = &[];

        // Built as an `io::Error` so that `?` converts it exactly like the
        // errors produced by `write_all` below.
        let length_error = |what: &str, len: usize| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("{what} of {len} bytes does not fit its length prefix"),
            )
        };

        let mut cursor = Cursor::new(to);

        for child in &self.children {
            let bytes = child.map(|addr| addr.to_le_bytes()).unwrap_or_default();
            cursor.write_all(&bytes)?;
        }

        let (value_len, value) = match &self.value {
            Some(val) => {
                let len = DataLen::try_from(val.len())
                    .ok()
                    // `DataLen::MAX` is the "no value" marker and must never
                    // be emitted as a real length.
                    .filter(|len| *len != DataLen::MAX)
                    .ok_or_else(|| length_error("branch value", val.len()))?;

                (len, val.deref())
            }
            None => (DataLen::MAX, EMPTY),
        };

        cursor.write_all(&value_len.to_le_bytes())?;
        cursor.write_all(value)?;

        for child_encoded in &self.children_encoded {
            // An absent encoded child is written as a zero length and no bytes.
            let child = child_encoded.as_deref().unwrap_or(EMPTY);
            let child_len = EncodedChildLen::try_from(child.len())
                .map_err(|_| length_error("encoded child", child.len()))?;

            cursor.write_all(&child_len.to_le_bytes())?;
            cursor.write_all(child)?;
        }

        Ok(())
    }

    /// Not implemented yet: calling this panics through `todo!()`.
    ///
    /// NOTE(review): branch nodes appear to be read back by the `Storable`
    /// implementation of `Node` instead — confirm before relying on this.
    fn deserialize<T: crate::shale::CachedStore>(
        _addr: usize,
        _mem: &T,
    ) -> Result<Self, crate::shale::ShaleError>
    where
        Self: Sized,
    {
        todo!()
    }
}

/// Returns the serialized size of an optional, length-prefixed byte payload.
///
/// `Len` is the integer type used for the length prefix; its size in bytes is
/// always counted. The payload contributes its own byte length when present
/// and nothing when `data` is `None`.
fn optional_data_len<Len, T: AsRef<[u8]>>(data: Option<T>) -> u64 {
    let prefix_len = size_of::<Len>() as u64;
    let payload_len = match &data {
        Some(bytes) => bytes.as_ref().len() as u64,
        None => 0,
    };

    prefix_len + payload_len
}
6 changes: 2 additions & 4 deletions firewood/src/shale/disk_address.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ impl DerefMut for DiskAddress {
}

impl DiskAddress {
pub(crate) const MSIZE: u64 = size_of::<Self>() as u64;

/// Return a None DiskAddress
pub fn null() -> Self {
DiskAddress(None)
Expand Down Expand Up @@ -160,10 +162,6 @@ impl std::ops::BitAnd<usize> for DiskAddress {
}
}

impl DiskAddress {
/// Size of a `DiskAddress` in bytes, as reported by `size_of::<Self>()`.
/// The `Storable` implementation below returns this from `serialized_len`.
const MSIZE: u64 = size_of::<Self>() as u64;
}

impl Storable for DiskAddress {
fn serialized_len(&self) -> u64 {
Self::MSIZE
Expand Down

0 comments on commit 4a4f4ba

Please sign in to comment.