From b2b4da89947b98ca7a4b4ddb7453847a01a22c06 Mon Sep 17 00:00:00 2001
From: Richard Pringle
Date: Wed, 22 Nov 2023 19:23:22 -0500
Subject: [PATCH] Use more strict types in node-(de)serialization (#372)

---
 firewood/src/merkle/node.rs | 161 ++++++++++++++++++++++--------------
 1 file changed, 98 insertions(+), 63 deletions(-)

diff --git a/firewood/src/merkle/node.rs b/firewood/src/merkle/node.rs
index a7fe8bf34..b0d5c4540 100644
--- a/firewood/src/merkle/node.rs
+++ b/firewood/src/merkle/node.rs
@@ -10,6 +10,7 @@ use sha3::{Digest, Keccak256};
 use std::{
     fmt::Debug,
     io::{Cursor, Read, Write},
+    mem::size_of,
     sync::{
         atomic::{AtomicBool, Ordering},
         OnceLock,
@@ -205,15 +206,6 @@ bitflags! {
 }
 
 impl Node {
-    const BRANCH_NODE: u8 = 0;
-    const LEAF_NODE: u8 = 1;
-    const EXT_NODE: u8 = 2;
-
-    const ROOT_HASH_VALID_BIT: u8 = 0b01;
-    // TODO: why are these different?
-    const IS_ENCODED_BIG_VALID: u8 = 0b10;
-    const LONG_BIT: u8 = 0b100;
-
     pub(super) fn max_branch_node_size() -> u64 {
         let max_size: OnceLock<u64> = OnceLock::new();
         *max_size.get_or_init(|| {
@@ -305,6 +297,17 @@ impl Node {
     }
 }
 
+#[repr(C)]
+struct Meta {
+    root_hash: [u8; TRIE_HASH_LEN],
+    attrs: NodeAttributes,
+    is_encoded_longer_than_hash_len: Option<bool>,
+}
+
+impl Meta {
+    const SIZE: usize = size_of::<Self>();
+}
+
 mod type_id {
     use crate::shale::ShaleError;
 
@@ -337,13 +340,11 @@ use type_id::NodeTypeId;
 
 impl Storable for Node {
     fn deserialize<T: CachedStore>(addr: usize, mem: &T) -> Result<Self, ShaleError> {
-        const META_SIZE: usize = TRIE_HASH_LEN + 1 + 1;
-
         let meta_raw =
-            mem.get_view(addr, META_SIZE as u64)
+            mem.get_view(addr, Meta::SIZE as u64)
                 .ok_or(ShaleError::InvalidCacheView {
                     offset: addr,
-                    size: META_SIZE as u64,
+                    size: Meta::SIZE as u64,
                 })?;
 
         let attrs = NodeAttributes::from_bits_retain(meta_raw.as_deref()[TRIE_HASH_LEN]);
@@ -368,16 +369,19 @@
         match meta_raw.as_deref()[TRIE_HASH_LEN + 1].try_into()? {
             NodeTypeId::Branch => {
                 // TODO: add path
+                // TODO: figure out what this size is?
                 let branch_header_size = MAX_CHILDREN as u64 * 8 + 4;
-                let node_raw = mem.get_view(addr + META_SIZE, branch_header_size).ok_or(
+                let node_raw = mem.get_view(addr + Meta::SIZE, branch_header_size).ok_or(
                     ShaleError::InvalidCacheView {
-                        offset: addr + META_SIZE,
+                        offset: addr + Meta::SIZE,
                         size: branch_header_size,
                     },
                 )?;
+
                 let mut cur = Cursor::new(node_raw.as_deref());
                 let mut chd = [None; MAX_CHILDREN];
                 let mut buff = [0; 8];
+
                 for chd in chd.iter_mut() {
                     cur.read_exact(&mut buff)?;
                     let addr = usize::from_le_bytes(buff);
                     if addr != 0 {
                         *chd = Some(DiskAddress::from(addr))
                     }
                 }
+
                 cur.read_exact(&mut buff[..4])?;
+
                 let raw_len = u32::from_le_bytes(buff[..4].try_into().expect("invalid slice")) as u64;
+
                 let value = if raw_len == u32::MAX as u64 {
                     None
                 } else {
                     Some(Data(
-                        mem.get_view(addr + META_SIZE + branch_header_size as usize, raw_len)
+                        mem.get_view(addr + Meta::SIZE + branch_header_size as usize, raw_len)
                             .ok_or(ShaleError::InvalidCacheView {
-                                offset: addr + META_SIZE + branch_header_size as usize,
+                                offset: addr + Meta::SIZE + branch_header_size as usize,
                                 size: raw_len,
                             })?
                            .as_deref(),
                    ))
                };
+
                let mut chd_encoded: [Option<Vec<u8>>; MAX_CHILDREN] = Default::default();
+
                let offset = if raw_len == u32::MAX as u64 {
-                    addr + META_SIZE + branch_header_size as usize
+                    addr + Meta::SIZE + branch_header_size as usize
                } else {
-                    addr + META_SIZE + branch_header_size as usize + raw_len as usize
+                    addr + Meta::SIZE + branch_header_size as usize + raw_len as usize
                };
+
                let mut cur_encoded_len = 0;
+
                for chd_encoded in chd_encoded.iter_mut() {
                    let mut buff = [0_u8; 1];
                    let len_raw = mem.get_view(offset + cur_encoded_len, 1).ok_or(
@@ -415,10 +426,13 @@ impl Storable for Node {
                            size: 1,
                        },
                    )?;
+
                    cur = Cursor::new(len_raw.as_deref());
                    cur.read_exact(&mut buff)?;
+
                    let len = buff[0] as u64;
                    cur_encoded_len += 1;
+
                    if len != 0 {
                        let encoded_raw = mem.get_view(offset + cur_encoded_len, len).ok_or(
                            ShaleError::InvalidCacheView {
                                offset: offset + cur_encoded_len,
                                size: len,
                            },
                        )?;
@@ -426,46 +440,53 @@
+
                        let encoded: Vec<u8> = encoded_raw.as_deref()[0..].to_vec();
                        *chd_encoded = Some(encoded);
                        cur_encoded_len += len as usize
                    }
                }
+                let inner = NodeType::Branch(
+                    BranchNode {
+                        // path: vec![].into(),
+                        children: chd,
+                        value,
+                        children_encoded: chd_encoded,
+                    }
+                    .into(),
+                );
+
                Ok(Self::new_from_hash(
                    root_hash,
                    is_encoded_longer_than_hash_len,
-                    NodeType::Branch(
-                        BranchNode {
-                            // path: vec![].into(),
-                            children: chd,
-                            value,
-                            children_encoded: chd_encoded,
-                        }
-                        .into(),
-                    ),
+                    inner,
                ))
            }
            NodeTypeId::Extension => {
                let ext_header_size = 1 + 8;
+
-                let node_raw = mem.get_view(addr + META_SIZE, ext_header_size).ok_or(
+                let node_raw = mem.get_view(addr + Meta::SIZE, ext_header_size).ok_or(
                    ShaleError::InvalidCacheView {
-                        offset: addr + META_SIZE,
+                        offset: addr + Meta::SIZE,
                        size: ext_header_size,
                    },
                )?;
+
                let mut cur = Cursor::new(node_raw.as_deref());
                let mut buff = [0; 8];
+
                cur.read_exact(&mut buff[..1])?;
                let path_len = buff[0] as u64;
+
                cur.read_exact(&mut buff)?;
                let ptr = u64::from_le_bytes(buff);

                let nibbles: Vec<u8> = mem
-                    .get_view(addr + META_SIZE + ext_header_size as usize, path_len)
+                    .get_view(addr + Meta::SIZE + ext_header_size as usize, path_len)
                    .ok_or(ShaleError::InvalidCacheView {
-                        offset: addr + META_SIZE + ext_header_size as usize,
+                        offset: addr + Meta::SIZE + ext_header_size as usize,
                        size: path_len,
                    })?
                    .as_deref()
@@ -476,13 +497,14 @@ impl Storable for Node {
                let (path, _) = PartialPath::decode(&nibbles);

                let mut buff = [0_u8; 1];
+
                let encoded_len_raw = mem
                    .get_view(
-                        addr + META_SIZE + ext_header_size as usize + path_len as usize,
+                        addr + Meta::SIZE + ext_header_size as usize + path_len as usize,
                        1,
                    )
                    .ok_or(ShaleError::InvalidCacheView {
-                        offset: addr + META_SIZE + ext_header_size as usize + path_len as usize,
+                        offset: addr + Meta::SIZE + ext_header_size as usize + path_len as usize,
                        size: 1,
                    })?;
@@ -494,12 +516,12 @@ impl Storable for Node {
                let encoded: Option<Vec<u8>> = if encoded_len != 0 {
                    let emcoded_raw = mem
                        .get_view(
-                            addr + META_SIZE + ext_header_size as usize + path_len as usize + 1,
+                            addr + Meta::SIZE + ext_header_size as usize + path_len as usize + 1,
                            encoded_len,
                        )
                        .ok_or(ShaleError::InvalidCacheView {
                            offset: addr
-                                + META_SIZE
+                                + Meta::SIZE
                                + ext_header_size as usize
                                + path_len as usize
                                + 1,
@@ -511,24 +533,22 @@ impl Storable for Node {
                    None
                };

-                let node = Self::new_from_hash(
-                    root_hash,
-                    is_encoded_longer_than_hash_len,
-                    NodeType::Extension(ExtNode {
-                        path,
-                        child: DiskAddress::from(ptr as usize),
-                        child_encoded: encoded,
-                    }),
-                );
+                let inner = NodeType::Extension(ExtNode {
+                    path,
+                    child: DiskAddress::from(ptr as usize),
+                    child_encoded: encoded,
+                });
+
+                let node = Self::new_from_hash(root_hash, is_encoded_longer_than_hash_len, inner);

                Ok(node)
            }
            NodeTypeId::Leaf => {
                let leaf_header_size = 1 + 4;
-                let node_raw = mem.get_view(addr + META_SIZE, leaf_header_size).ok_or(
+                let node_raw = mem.get_view(addr + Meta::SIZE, leaf_header_size).ok_or(
                    ShaleError::InvalidCacheView {
-                        offset: addr + META_SIZE,
+                        offset: addr + Meta::SIZE,
                        size: leaf_header_size,
                    },
                )?;
@@ -544,11 +564,11 @@ impl Storable for Node {
                let data_len = u32::from_le_bytes(buff) as u64;

                let remainder = mem
                    .get_view(
-                        addr + META_SIZE + leaf_header_size as usize,
+                        addr + Meta::SIZE + leaf_header_size as usize,
                        path_len + data_len,
                    )
                    .ok_or(ShaleError::InvalidCacheView {
-                        offset: addr + META_SIZE + leaf_header_size as usize,
+                        offset: addr + Meta::SIZE + leaf_header_size as usize,
                        size: path_len + data_len,
                    })?;
@@ -574,8 +594,7 @@ impl Storable for Node {
    }

    fn serialized_len(&self) -> u64 {
-        32 + 1
-            + 1
+        Meta::SIZE as u64
            + match &self.inner {
                NodeType::Branch(n) => {
                    // TODO: add path
@@ -609,33 +628,39 @@ impl Storable for Node {
    fn serialize(&self, to: &mut [u8]) -> Result<(), ShaleError> {
        let mut cur = Cursor::new(to);

-        let mut attrs = 0;
-        attrs |= match self.root_hash.get() {
+        let mut attrs = match self.root_hash.get() {
            Some(h) => {
                cur.write_all(&h.0)?;
-                Node::ROOT_HASH_VALID_BIT
+                NodeAttributes::ROOT_HASH_VALID
            }
            None => {
                cur.write_all(&[0; 32])?;
-                0
+                NodeAttributes::empty()
            }
        };
-        attrs |= match self.is_encoded_longer_than_hash_len.get() {
-            Some(b) => (if *b { Node::LONG_BIT } else { 0 } | Node::IS_ENCODED_BIG_VALID),
-            None => 0,
-        };
-        cur.write_all(&[attrs]).unwrap();
+
+        if let Some(&b) = self.is_encoded_longer_than_hash_len.get() {
+            attrs.insert(if b {
+                NodeAttributes::LONG
+            } else {
+                NodeAttributes::IS_ENCODED_BIG_VALID
+            });
+        }
+
+        cur.write_all(&[attrs.bits()]).unwrap();

        match &self.inner {
            NodeType::Branch(n) => {
                // TODO: add path
-                cur.write_all(&[Self::BRANCH_NODE]).unwrap();
+                cur.write_all(&[type_id::NodeTypeId::Branch as u8]).unwrap();
+
                for c in n.children.iter() {
                    cur.write_all(&match c {
                        Some(p) => p.to_le_bytes(),
                        None => 0u64.to_le_bytes(),
                    })?;
                }
+
                match &n.value {
                    Some(val) => {
                        cur.write_all(&(val.len() as u32).to_le_bytes())?;
@@ -645,6 +670,7 @@ impl Storable for Node {
                        cur.write_all(&u32::MAX.to_le_bytes())?;
                    }
                }
+
                // Since child encoding will only be unset after initialization (only used for range proof),
                // it is fine to encode its value adjacent to other fields. Same for extention node.
                for encoded in n.children_encoded.iter() {
@@ -656,23 +682,32 @@ impl Storable for Node {
                        None => cur.write_all(&0u8.to_le_bytes())?,
                    }
                }
+
                Ok(())
            }
+
            NodeType::Extension(n) => {
-                cur.write_all(&[Self::EXT_NODE])?;
+                cur.write_all(&[type_id::NodeTypeId::Extension as u8])?;
+
                let path: Vec<u8> = from_nibbles(&n.path.encode(false)).collect();
+
                cur.write_all(&[path.len() as u8])?;
                cur.write_all(&n.child.to_le_bytes())?;
                cur.write_all(&path)?;
+
                if let Some(encoded) = n.chd_encoded() {
                    cur.write_all(&[encoded.len() as u8])?;
                    cur.write_all(encoded)?;
                }
+
                Ok(())
            }
+
            NodeType::Leaf(n) => {
-                cur.write_all(&[Self::LEAF_NODE])?;
+                cur.write_all(&[type_id::NodeTypeId::Leaf as u8])?;
+
                let path: Vec<u8> = from_nibbles(&n.path.encode(true)).collect();
+
                cur.write_all(&[path.len() as u8])?;
                cur.write_all(&(n.data.len() as u32).to_le_bytes())?;
                cur.write_all(&path)?;
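
Note (illustrative only, not part of the patch): the standalone sketch below shows the two ideas the change relies on, namely computing the node-metadata size from a #[repr(C)] header struct via size_of::<Self>() instead of the hand-written TRIE_HASH_LEN + 1 + 1 constant, and tagging node kinds with an enum instead of bare u8 constants. The TRIE_HASH_LEN value, the simplified attrs field type, and the explicit discriminants are assumptions that mirror the constants removed above; the real definitions live in firewood's merkle/node.rs and its type_id module.

// Standalone sketch under the assumptions stated above; compiles on its own.
use std::mem::size_of;

const TRIE_HASH_LEN: usize = 32; // assumption: Keccak-256 digest length

#[allow(dead_code)]
#[repr(C)]
struct Meta {
    root_hash: [u8; TRIE_HASH_LEN],
    attrs: u8,                                     // stand-in for the bitflags-backed NodeAttributes
    is_encoded_longer_than_hash_len: Option<bool>, // Option<bool> occupies 1 byte (niche optimization)
}

impl Meta {
    const SIZE: usize = size_of::<Self>();
}

// Discriminants mirror the removed constants: BRANCH_NODE = 0, LEAF_NODE = 1, EXT_NODE = 2.
#[allow(dead_code)]
#[repr(u8)]
enum NodeTypeId {
    Branch = 0,
    Leaf = 1,
    Extension = 2,
}

fn main() {
    // Every field is 1-byte aligned, so #[repr(C)] introduces no padding and the
    // computed size matches the old TRIE_HASH_LEN + 1 + 1 layout (34 bytes).
    assert_eq!(Meta::SIZE, TRIE_HASH_LEN + 1 + 1);
    println!("Meta::SIZE = {}", Meta::SIZE);
    println!("leaf tag = {}", NodeTypeId::Leaf as u8);
}

Writing the on-disk tag as `NodeTypeId::Leaf as u8` keeps the serialized byte identical to the old constant while letting deserialization go through a checked try_into() conversion rather than matching on raw integers.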