diff --git a/firewood/src/db.rs b/firewood/src/db.rs
index de1ff059c..09a0285eb 100644
--- a/firewood/src/db.rs
+++ b/firewood/src/db.rs
@@ -1,12 +1,6 @@
 // Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
 // See the file LICENSE.md for licensing terms.
 
-use crate::shale::{
-    self,
-    compact::{CompactSpace, CompactSpaceHeader},
-    disk_address::DiskAddress,
-    CachedStore, Obj, ShaleError, ShaleStore, SpaceId, Storable, StoredView,
-};
 pub use crate::{
     config::{DbConfig, DbRevConfig},
     storage::{buffer::DiskBufferConfig, WalConfig},
@@ -23,8 +17,18 @@ use crate::{
     },
     v2::api::{self, HashKey, KeyType, Proof, ValueType},
 };
+use crate::{
+    merkle,
+    shale::{
+        self,
+        compact::{CompactSpace, CompactSpaceHeader},
+        disk_address::DiskAddress,
+        CachedStore, Obj, ShaleError, ShaleStore, SpaceId, Storable, StoredView,
+    },
+};
 use async_trait::async_trait;
 use bytemuck::{cast_slice, AnyBitPattern};
+
 use metered::metered;
 use parking_lot::{Mutex, RwLock};
 use std::{
@@ -312,6 +316,16 @@ impl<S: ShaleStore<Node> + Send + Sync> api::DbView for DbRev<S> {
 }
 
 impl<S: ShaleStore<Node> + Send + Sync> DbRev<S> {
+    pub fn stream<K: KeyType>(
+        &self,
+        start_key: Option<K>,
+    ) -> Result<merkle::MerkleKeyValueStream<'_, S>, api::Error> {
+        // TODO: get first key when start_key is None
+        self.merkle
+            .get_iter(start_key, self.header.kv_root)
+            .map_err(|e| api::Error::InternalError(e.into()))
+    }
+
     fn flush_dirty(&mut self) -> Option<()> {
         self.header.flush_dirty();
         self.merkle.flush_dirty()?;
diff --git a/firewood/src/merkle.rs b/firewood/src/merkle.rs
index f1121ce7e..829978c64 100644
--- a/firewood/src/merkle.rs
+++ b/firewood/src/merkle.rs
@@ -2,7 +2,9 @@
 // See the file LICENSE.md for licensing terms.
 
 use crate::shale::{self, disk_address::DiskAddress, ObjWriteError, ShaleError, ShaleStore};
+use crate::v2::api;
 use crate::{nibbles::Nibbles, v2::api::Proof};
+use futures::Stream;
 use sha3::Digest;
 use std::{
     cmp::Ordering,
@@ -10,6 +12,7 @@ use std::{
     io::Write,
     iter::once,
     sync::{atomic::Ordering::Relaxed, OnceLock},
+    task::Poll,
 };
 use thiserror::Error;
 
@@ -1175,6 +1178,220 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
     pub fn flush_dirty(&self) -> Option<()> {
         self.store.flush_dirty()
     }
+
+    pub fn get_iter<K: AsRef<[u8]>>(
+        &self,
+        key: Option<K>,
+        root: DiskAddress,
+    ) -> Result<MerkleKeyValueStream<'_, S>, MerkleError> {
+        // TODO: if DiskAddress::is_null() ...
+        Ok(MerkleKeyValueStream {
+            key_state: IteratorState::new(key),
+            merkle_root: root,
+            merkle: self,
+        })
+    }
+}
+
+enum IteratorState<'a> {
+    /// Start iterating at the beginning of the trie,
+    /// returning the lowest key/value pair first
+    StartAtBeginning,
+    /// Start iterating at the specified key
+    StartAtKey(Vec<u8>),
+    /// Continue iterating after the given last_node and parents
+    Iterating {
+        last_node: ObjRef<'a>,
+        parents: Vec<(ObjRef<'a>, u8)>,
+    },
+}
+impl IteratorState<'_> {
+    fn new<K: AsRef<[u8]>>(starting: Option<K>) -> Self {
+        match starting {
+            None => Self::StartAtBeginning,
+            Some(key) => Self::StartAtKey(key.as_ref().to_vec()),
+        }
+    }
+}
+
+// The default state is to start at the beginning
+impl<'a> Default for IteratorState<'a> {
+    fn default() -> Self {
+        Self::StartAtBeginning
+    }
+}
+pub struct MerkleKeyValueStream<'a, S> {
+    key_state: IteratorState<'a>,
+    merkle_root: DiskAddress,
+    merkle: &'a Merkle<S>,
+}
+
+impl<'a, S: shale::ShaleStore<Node> + Send + Sync> Stream for MerkleKeyValueStream<'a, S> {
+    type Item = Result<(Vec<u8>, Vec<u8>), api::Error>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        _cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        // Note that this sets the key_state to StartAtBeginning temporarily
+        // - if you get to the end you get Ok(None) but can fetch again from the start
+        // - if you get an error, you'll get Err(...), but continuing to fetch starts from the top
+        // If this isn't what you want, then consider using [std::iter::fuse]
+        let found_key = match std::mem::take(&mut self.key_state) {
+            IteratorState::StartAtBeginning => todo!(),
+            IteratorState::StartAtKey(key) => {
+                // TODO: support finding the next key after K
+                let root_node = self
+                    .merkle
+                    .get_node(self.merkle_root)
+                    .map_err(|e| api::Error::InternalError(e.into()))?;
+
+                let (found_node, parents) = self
+                    .merkle
+                    .get_node_and_parents_by_key(root_node, &key)
+                    .map_err(|e| api::Error::InternalError(e.into()))?;
+
+                let Some(last_node) = found_node else {
+                    return Poll::Ready(None);
+                };
+
+                let returned_key_value = match last_node.inner() {
+                    NodeType::Branch(branch) => (key, branch.value.to_owned().unwrap().to_vec()),
+                    NodeType::Leaf(leaf) => (key, leaf.1.to_vec()),
+                    NodeType::Extension(_) => todo!(),
+                };
+
+                self.key_state = IteratorState::Iterating { last_node, parents };
+
+                return Poll::Ready(Some(Ok(returned_key_value)));
+            }
+            IteratorState::Iterating {
+                last_node,
+                mut parents,
+            } => {
+                match last_node.inner() {
+                    NodeType::Branch(branch) => {
+                        // previously rendered the value from a branch node, so walk down to the first available child
+                        if let Some(child_position) =
+                            branch.children.iter().position(|&addr| addr.is_some())
+                        {
+                            let child_address = branch.children[child_position].unwrap();
+
+                            parents.push((last_node, child_position as u8)); // remember where we walked down from
+
+                            let current_node = self
+                                .merkle
+                                .get_node(child_address)
+                                .map_err(|e| api::Error::InternalError(e.into()))?;
+
+                            let found_key = parents[1..]
+                                .chunks_exact(2)
+                                .map(|parents| (parents[0].1 << 4) + parents[1].1)
+                                .collect::<Vec<u8>>();
+
+                            self.key_state = IteratorState::Iterating {
+                                // continue iterating from here
+                                last_node: current_node,
+                                parents,
+                            };
+
+                            found_key
+                        } else {
+                            // Branch node with no children?
+                            return Poll::Ready(Some(Err(api::Error::InternalError(Box::new(
+                                MerkleError::ParentLeafBranch,
+                            )))));
+                        }
+                    }
+                    NodeType::Leaf(leaf) => {
+                        let mut next = parents.pop();
+                        loop {
+                            match next {
+                                None => return Poll::Ready(None),
+                                Some((parent, child_position)) => {
+                                    // Assume all parents are branch nodes
+                                    let children = parent.inner().as_branch().unwrap().chd();
+
+                                    // we use wrapping_add here because the value might be u8::MAX indicating that
+                                    // we want to go down branch
+                                    let mut child_position = child_position.wrapping_add(1);
+
+                                    if let Some(found_offset) = children[child_position as usize..]
+                                        .iter()
+                                        .position(|&addr| addr.is_some())
+                                    {
+                                        child_position += found_offset as u8;
+                                    } else {
+                                        next = parents.pop();
+                                        continue;
+                                    }
+
+                                    let addr = children[child_position as usize].unwrap();
+
+                                    // we push (node, u8::MAX) because we will add 1, which will wrap to 0
+                                    let child = self
+                                        .merkle
+                                        .get_node(addr)
+                                        .map(|node| (node, u8::MAX))
+                                        .map_err(|e| api::Error::InternalError(e.into()))?;
+
+                                    // TODO: If the branch has a value, then we shouldn't keep_going
+                                    let keep_going_down = child.0.inner().is_branch();
+
+                                    next = Some(child);
+
+                                    parents.push((parent, child_position));
+
+                                    if !keep_going_down {
+                                        break;
+                                    }
+                                }
+                            }
+                        }
+                        // recompute current_key
+                        // TODO: Can we keep current_key updated as we walk the tree instead of building it from the top all the time?
+                        let mut current_key = parents[1..]
+                            .chunks_exact(2)
+                            .map(|parents| (parents[0].1 << 4) + parents[1].1)
+                            .collect::<Vec<u8>>();
+
+                        current_key.extend(leaf.0.to_vec());
+
+                        self.key_state = IteratorState::Iterating {
+                            last_node: next.unwrap().0,
+                            parents,
+                        };
+
+                        current_key
+                    }
+
+                    NodeType::Extension(_) => todo!(),
+                }
+            }
+        };
+
+        // figure out the value to return from the state
+        // if we get here, we're sure to have something to return
+        // TODO: It's possible to return a reference to the data since the last_node is
+        // saved in the iterator
+        let return_value = match &self.key_state {
+            IteratorState::Iterating {
+                last_node,
+                parents: _,
+            } => {
+                let value = match last_node.inner() {
+                    NodeType::Branch(branch) => branch.value.to_owned().unwrap().to_vec(),
+                    NodeType::Leaf(leaf) => leaf.1.to_vec(),
+                    NodeType::Extension(_) => todo!(),
+                };
+
+                (found_key, value)
+            }
+            _ => unreachable!(),
+        };
+
+        Poll::Ready(Some(Ok(return_value)))
+    }
 }
 
 fn set_parent(new_chd: DiskAddress, parents: &mut [(ObjRef, u8)]) {
@@ -1270,10 +1487,12 @@ pub fn from_nibbles(nibbles: &[u8]) -> impl Iterator<Item = u8> + '_ {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use futures::StreamExt;
     use node::tests::{extension, leaf};
     use shale::{cached::DynamicMem, compact::CompactSpace, CachedStore};
     use std::sync::Arc;
     use test_case::test_case;
+    //use itertools::Itertools;
 
     #[test_case(vec![0x12, 0x34, 0x56], vec![0x1, 0x2, 0x3, 0x4, 0x5, 0x6])]
     #[test_case(vec![0xc0, 0xff], vec![0xc, 0x0, 0xf, 0xf])]
@@ -1388,6 +1607,33 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    async fn iterate_empty() {
+        let merkle = create_test_merkle();
+        let root = merkle.init_root().unwrap();
+        let mut it = merkle.get_iter(Some(b"x"), root).unwrap();
+        let next = it.next().await;
+        assert!(next.is_none())
+    }
+
+    #[tokio::test]
+    async fn iterate_many() {
+        let mut merkle = create_test_merkle();
+        let root = merkle.init_root().unwrap();
+
+        for k in u8::MIN..=u8::MAX {
+            merkle.insert([k], vec![k], root).unwrap();
+        }
+
+        let mut it = merkle.get_iter(Some([u8::MIN]), root).unwrap();
+        for k in u8::MIN..=u8::MAX {
+            let next = it.next().await.unwrap().unwrap();
+            assert_eq!(next.0, next.1);
+            assert_eq!(next.1, vec![k]);
+        }
+        assert!(it.next().await.is_none())
+    }
+
     #[test]
     fn remove_one() {
         let key = b"hello";
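
A minimal usage sketch for the new streaming API, assuming the call site lives inside the firewood crate (so `DbRev`, `ShaleStore`, `Node`, and `api` are in scope as in `db.rs` above), that `rev` is a `DbRev<S>` for an open revision, and that `first_key` is a key already present in that revision; the `dump_from` name is illustrative, not part of the diff. Note that, per the `todo!()` in `poll_next`, passing `None` to start from the beginning is not implemented yet.

    // Sketch only: assumes DbRev, ShaleStore, Node, and api are in scope as in db.rs.
    use futures::StreamExt;

    async fn dump_from<S: ShaleStore<Node> + Send + Sync>(
        rev: &DbRev<S>,
        first_key: &[u8],
    ) -> Result<(), api::Error> {
        // DbRev::stream wraps Merkle::get_iter over the revision's kv_root
        let mut kv = rev.stream(Some(first_key))?;

        // MerkleKeyValueStream yields Result<(Vec<u8>, Vec<u8>), api::Error>
        while let Some(item) = kv.next().await {
            let (key, value) = item?;
            println!("{key:?} => {value:?}");
        }

        Ok(())
    }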