Use more strict types in node-(de)serialization (#372)
richardpringle authored Nov 23, 2023
1 parent 0267229 commit b2b4da8
Showing 1 changed file with 98 additions and 63 deletions.
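In short, the commit swaps hand-rolled byte constants and a hand-computed header size for typed equivalents: the node-type tag byte is parsed into a NodeTypeId enum via TryFrom, the attribute bits live in the bitflags-backed NodeAttributes, and the serialized prefix is described by a #[repr(C)] Meta struct whose size the compiler derives with mem::size_of. The sketch below mirrors that pattern in isolation; it is not the firewood code. TRIE_HASH_LEN, InvalidNodeMeta, and the simplified Meta fields are stand-ins chosen so the example compiles on its own, and the TryFrom impl is a guess at what the collapsed type_id module looks like.

use std::mem::size_of;

// Stand-ins so the sketch compiles on its own; the real code gets these
// from the surrounding firewood modules.
const TRIE_HASH_LEN: usize = 32;

#[derive(Debug)]
struct InvalidNodeMeta;

// Typed tag replacing `const BRANCH_NODE: u8 = 0;` and friends.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum NodeTypeId {
    Branch = 0,
    Leaf = 1,
    Extension = 2,
}

impl TryFrom<u8> for NodeTypeId {
    type Error = InvalidNodeMeta;

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Branch),
            1 => Ok(Self::Leaf),
            2 => Ok(Self::Extension),
            _ => Err(InvalidNodeMeta),
        }
    }
}

// The serialized prefix as a struct, so its size is computed by the compiler
// instead of being maintained by hand as `TRIE_HASH_LEN + 1 + 1`.
#[allow(dead_code)]
#[repr(C)]
struct Meta {
    root_hash: [u8; TRIE_HASH_LEN],
    attrs: u8,   // bitflags-backed `NodeAttributes` in the real code
    type_id: u8, // the tag byte parsed into `NodeTypeId` above
}

impl Meta {
    const SIZE: usize = size_of::<Self>();
}

fn main() -> Result<(), InvalidNodeMeta> {
    // Unknown tag bytes now surface as an error instead of falling through
    // a `match` over magic numbers.
    assert_eq!(NodeTypeId::try_from(2u8)?, NodeTypeId::Extension);
    assert!(NodeTypeId::try_from(9u8).is_err());

    // Same 34-byte prefix as before, but the arithmetic is the compiler's job.
    assert_eq!(Meta::SIZE, TRIE_HASH_LEN + 1 + 1);
    Ok(())
}

The tag values match the removed BRANCH_NODE/LEAF_NODE/EXT_NODE constants (0, 1, 2), so the on-disk format is unchanged; only the parsing becomes type-checked.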
161 changes: 98 additions & 63 deletions firewood/src/merkle/node.rs
@@ -10,6 +10,7 @@ use sha3::{Digest, Keccak256};
use std::{
fmt::Debug,
io::{Cursor, Read, Write},
mem::size_of,
sync::{
atomic::{AtomicBool, Ordering},
OnceLock,
@@ -205,15 +206,6 @@ bitflags! {
}

impl Node {
const BRANCH_NODE: u8 = 0;
const LEAF_NODE: u8 = 1;
const EXT_NODE: u8 = 2;

const ROOT_HASH_VALID_BIT: u8 = 0b01;
// TODO: why are these different?
const IS_ENCODED_BIG_VALID: u8 = 0b10;
const LONG_BIT: u8 = 0b100;

pub(super) fn max_branch_node_size() -> u64 {
let max_size: OnceLock<u64> = OnceLock::new();
*max_size.get_or_init(|| {
@@ -305,6 +297,17 @@ impl Node {
}
}

#[repr(C)]
struct Meta {
root_hash: [u8; TRIE_HASH_LEN],
attrs: NodeAttributes,
is_encoded_longer_than_hash_len: Option<bool>,
}

impl Meta {
const SIZE: usize = size_of::<Self>();
}

mod type_id {
use crate::shale::ShaleError;

@@ -337,13 +340,11 @@ use type_id::NodeTypeId;

impl Storable for Node {
fn deserialize<T: CachedStore>(addr: usize, mem: &T) -> Result<Self, ShaleError> {
const META_SIZE: usize = TRIE_HASH_LEN + 1 + 1;

let meta_raw =
mem.get_view(addr, META_SIZE as u64)
mem.get_view(addr, Meta::SIZE as u64)
.ok_or(ShaleError::InvalidCacheView {
offset: addr,
size: META_SIZE as u64,
size: Meta::SIZE as u64,
})?;

let attrs = NodeAttributes::from_bits_retain(meta_raw.as_deref()[TRIE_HASH_LEN]);
@@ -368,45 +369,55 @@ impl Storable for Node {
match meta_raw.as_deref()[TRIE_HASH_LEN + 1].try_into()? {
NodeTypeId::Branch => {
// TODO: add path
// TODO: figure out what this size is?
let branch_header_size = MAX_CHILDREN as u64 * 8 + 4;
let node_raw = mem.get_view(addr + META_SIZE, branch_header_size).ok_or(
let node_raw = mem.get_view(addr + Meta::SIZE, branch_header_size).ok_or(
ShaleError::InvalidCacheView {
offset: addr + META_SIZE,
offset: addr + Meta::SIZE,
size: branch_header_size,
},
)?;

let mut cur = Cursor::new(node_raw.as_deref());
let mut chd = [None; MAX_CHILDREN];
let mut buff = [0; 8];

for chd in chd.iter_mut() {
cur.read_exact(&mut buff)?;
let addr = usize::from_le_bytes(buff);
if addr != 0 {
*chd = Some(DiskAddress::from(addr))
}
}

cur.read_exact(&mut buff[..4])?;

let raw_len =
u32::from_le_bytes(buff[..4].try_into().expect("invalid slice")) as u64;

let value = if raw_len == u32::MAX as u64 {
None
} else {
Some(Data(
mem.get_view(addr + META_SIZE + branch_header_size as usize, raw_len)
mem.get_view(addr + Meta::SIZE + branch_header_size as usize, raw_len)
.ok_or(ShaleError::InvalidCacheView {
offset: addr + META_SIZE + branch_header_size as usize,
offset: addr + Meta::SIZE + branch_header_size as usize,
size: raw_len,
})?
.as_deref(),
))
};

let mut chd_encoded: [Option<Vec<u8>>; MAX_CHILDREN] = Default::default();

let offset = if raw_len == u32::MAX as u64 {
addr + META_SIZE + branch_header_size as usize
addr + Meta::SIZE + branch_header_size as usize
} else {
addr + META_SIZE + branch_header_size as usize + raw_len as usize
addr + Meta::SIZE + branch_header_size as usize + raw_len as usize
};

let mut cur_encoded_len = 0;

for chd_encoded in chd_encoded.iter_mut() {
let mut buff = [0_u8; 1];
let len_raw = mem.get_view(offset + cur_encoded_len, 1).ok_or(
@@ -415,57 +426,67 @@ impl Storable for Node {
size: 1,
},
)?;

cur = Cursor::new(len_raw.as_deref());
cur.read_exact(&mut buff)?;

let len = buff[0] as u64;
cur_encoded_len += 1;

if len != 0 {
let encoded_raw = mem.get_view(offset + cur_encoded_len, len).ok_or(
ShaleError::InvalidCacheView {
offset: offset + cur_encoded_len,
size: len,
},
)?;

let encoded: Vec<u8> = encoded_raw.as_deref()[0..].to_vec();
*chd_encoded = Some(encoded);
cur_encoded_len += len as usize
}
}

let inner = NodeType::Branch(
BranchNode {
// path: vec![].into(),
children: chd,
value,
children_encoded: chd_encoded,
}
.into(),
);

Ok(Self::new_from_hash(
root_hash,
is_encoded_longer_than_hash_len,
NodeType::Branch(
BranchNode {
// path: vec![].into(),
children: chd,
value,
children_encoded: chd_encoded,
}
.into(),
),
inner,
))
}

NodeTypeId::Extension => {
let ext_header_size = 1 + 8;
let node_raw = mem.get_view(addr + META_SIZE, ext_header_size).ok_or(

let node_raw = mem.get_view(addr + Meta::SIZE, ext_header_size).ok_or(
ShaleError::InvalidCacheView {
offset: addr + META_SIZE,
offset: addr + Meta::SIZE,
size: ext_header_size,
},
)?;

let mut cur = Cursor::new(node_raw.as_deref());
let mut buff = [0; 8];

cur.read_exact(&mut buff[..1])?;
let path_len = buff[0] as u64;

cur.read_exact(&mut buff)?;
let ptr = u64::from_le_bytes(buff);

let nibbles: Vec<u8> = mem
.get_view(addr + META_SIZE + ext_header_size as usize, path_len)
.get_view(addr + Meta::SIZE + ext_header_size as usize, path_len)
.ok_or(ShaleError::InvalidCacheView {
offset: addr + META_SIZE + ext_header_size as usize,
offset: addr + Meta::SIZE + ext_header_size as usize,
size: path_len,
})?
.as_deref()
@@ -476,13 +497,14 @@ impl Storable for Node {
let (path, _) = PartialPath::decode(&nibbles);

let mut buff = [0_u8; 1];

let encoded_len_raw = mem
.get_view(
addr + META_SIZE + ext_header_size as usize + path_len as usize,
addr + Meta::SIZE + ext_header_size as usize + path_len as usize,
1,
)
.ok_or(ShaleError::InvalidCacheView {
offset: addr + META_SIZE + ext_header_size as usize + path_len as usize,
offset: addr + Meta::SIZE + ext_header_size as usize + path_len as usize,
size: 1,
})?;

@@ -494,12 +516,12 @@ impl Storable for Node {
let encoded: Option<Vec<u8>> = if encoded_len != 0 {
let emcoded_raw = mem
.get_view(
addr + META_SIZE + ext_header_size as usize + path_len as usize + 1,
addr + Meta::SIZE + ext_header_size as usize + path_len as usize + 1,
encoded_len,
)
.ok_or(ShaleError::InvalidCacheView {
offset: addr
+ META_SIZE
+ Meta::SIZE
+ ext_header_size as usize
+ path_len as usize
+ 1,
@@ -511,24 +533,22 @@ impl Storable for Node {
None
};

let node = Self::new_from_hash(
root_hash,
is_encoded_longer_than_hash_len,
NodeType::Extension(ExtNode {
path,
child: DiskAddress::from(ptr as usize),
child_encoded: encoded,
}),
);
let inner = NodeType::Extension(ExtNode {
path,
child: DiskAddress::from(ptr as usize),
child_encoded: encoded,
});

let node = Self::new_from_hash(root_hash, is_encoded_longer_than_hash_len, inner);

Ok(node)
}

NodeTypeId::Leaf => {
let leaf_header_size = 1 + 4;
let node_raw = mem.get_view(addr + META_SIZE, leaf_header_size).ok_or(
let node_raw = mem.get_view(addr + Meta::SIZE, leaf_header_size).ok_or(
ShaleError::InvalidCacheView {
offset: addr + META_SIZE,
offset: addr + Meta::SIZE,
size: leaf_header_size,
},
)?;
@@ -544,11 +564,11 @@ impl Storable for Node {
let data_len = u32::from_le_bytes(buff) as u64;
let remainder = mem
.get_view(
addr + META_SIZE + leaf_header_size as usize,
addr + Meta::SIZE + leaf_header_size as usize,
path_len + data_len,
)
.ok_or(ShaleError::InvalidCacheView {
offset: addr + META_SIZE + leaf_header_size as usize,
offset: addr + Meta::SIZE + leaf_header_size as usize,
size: path_len + data_len,
})?;

@@ -574,8 +594,7 @@ impl Storable for Node {
}

fn serialized_len(&self) -> u64 {
32 + 1
+ 1
Meta::SIZE as u64
+ match &self.inner {
NodeType::Branch(n) => {
// TODO: add path
@@ -609,33 +628,39 @@ impl Storable for Node {
fn serialize(&self, to: &mut [u8]) -> Result<(), ShaleError> {
let mut cur = Cursor::new(to);

let mut attrs = 0;
attrs |= match self.root_hash.get() {
let mut attrs = match self.root_hash.get() {
Some(h) => {
cur.write_all(&h.0)?;
Node::ROOT_HASH_VALID_BIT
NodeAttributes::ROOT_HASH_VALID
}
None => {
cur.write_all(&[0; 32])?;
0
NodeAttributes::empty()
}
};
attrs |= match self.is_encoded_longer_than_hash_len.get() {
Some(b) => (if *b { Node::LONG_BIT } else { 0 } | Node::IS_ENCODED_BIG_VALID),
None => 0,
};
cur.write_all(&[attrs]).unwrap();

if let Some(&b) = self.is_encoded_longer_than_hash_len.get() {
attrs.insert(if b {
NodeAttributes::LONG
} else {
NodeAttributes::IS_ENCODED_BIG_VALID
});
}

cur.write_all(&[attrs.bits()]).unwrap();

match &self.inner {
NodeType::Branch(n) => {
// TODO: add path
cur.write_all(&[Self::BRANCH_NODE]).unwrap();
cur.write_all(&[type_id::NodeTypeId::Branch as u8]).unwrap();

for c in n.children.iter() {
cur.write_all(&match c {
Some(p) => p.to_le_bytes(),
None => 0u64.to_le_bytes(),
})?;
}

match &n.value {
Some(val) => {
cur.write_all(&(val.len() as u32).to_le_bytes())?;
Expand All @@ -645,6 +670,7 @@ impl Storable for Node {
cur.write_all(&u32::MAX.to_le_bytes())?;
}
}

// Since child encoding will only be unset after initialization (only used for range proof),
// it is fine to encode its value adjacent to other fields. Same for extension node.
for encoded in n.children_encoded.iter() {
Expand All @@ -656,23 +682,32 @@ impl Storable for Node {
None => cur.write_all(&0u8.to_le_bytes())?,
}
}

Ok(())
}

NodeType::Extension(n) => {
cur.write_all(&[Self::EXT_NODE])?;
cur.write_all(&[type_id::NodeTypeId::Extension as u8])?;

let path: Vec<u8> = from_nibbles(&n.path.encode(false)).collect();

cur.write_all(&[path.len() as u8])?;
cur.write_all(&n.child.to_le_bytes())?;
cur.write_all(&path)?;

if let Some(encoded) = n.chd_encoded() {
cur.write_all(&[encoded.len() as u8])?;
cur.write_all(encoded)?;
}

Ok(())
}

NodeType::Leaf(n) => {
cur.write_all(&[Self::LEAF_NODE])?;
cur.write_all(&[type_id::NodeTypeId::Leaf as u8])?;

let path: Vec<u8> = from_nibbles(&n.path.encode(true)).collect();

cur.write_all(&[path.len() as u8])?;
cur.write_all(&(n.data.len() as u32).to_le_bytes())?;
cur.write_all(&path)?;
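One layout detail worth noting about the new Meta struct above (an observer's note, not something the commit states): with #[repr(C)], the bitflags-backed NodeAttributes is a single byte and, with current rustc, Option<bool> also occupies a single byte thanks to niche optimization, so size_of::<Meta>() comes out to the same 34 bytes the deleted META_SIZE constant spelled out as TRIE_HASH_LEN + 1 + 1. The deserializer still reads the node-type tag from the byte at index TRIE_HASH_LEN + 1 of that view, so the Option<bool> field appears to serve mainly to keep the struct the same size as the on-disk prefix. A minimal check, with NodeAttributes replaced by a one-byte stand-in:

use std::mem::size_of;

const TRIE_HASH_LEN: usize = 32;

// Stand-in for the bitflags-generated type: `bitflags!` over a u8 backing
// store produces a one-byte newtype like this.
#[allow(dead_code)]
#[repr(transparent)]
struct NodeAttributes(u8);

#[allow(dead_code)]
#[repr(C)]
struct Meta {
    root_hash: [u8; TRIE_HASH_LEN],
    attrs: NodeAttributes,
    is_encoded_longer_than_hash_len: Option<bool>,
}

fn main() {
    // Option<bool> fits in one byte via niche optimization, so the whole
    // prefix matches the old hand-computed TRIE_HASH_LEN + 1 + 1.
    assert_eq!(size_of::<Option<bool>>(), 1);
    assert_eq!(size_of::<Meta>(), TRIE_HASH_LEN + 1 + 1);
}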
