Skip to content

Commit

Permalink
Merge pull request #83 from osmosis-labs/alpo/prefix-sum
Browse files Browse the repository at this point in the history
[Sumtree]: Implement prefix sum algorithm
  • Loading branch information
crnbarr93 authored Apr 3, 2024
2 parents a464835 + 58b0c52 commit f89a55e
Show file tree
Hide file tree
Showing 5 changed files with 441 additions and 35 deletions.
43 changes: 22 additions & 21 deletions contracts/sumtree-orderbook/src/sumtree/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,12 +562,12 @@ impl TreeNode {
/// have been performed. It checks the balance factor of the current node and performs rotations
/// as necessary to bring the tree back into balance.
pub fn rebalance(&mut self, storage: &mut dyn Storage) -> ContractResult<()> {
ensure!(self.is_internal(), ContractError::InvalidNodeType);
ensure!(self.has_child(), ContractError::ChildlessInternalNode);

// Synchronize the current node's state with storage before rebalancing.
self.sync(storage)?;

ensure!(self.is_internal(), ContractError::InvalidNodeType);
ensure!(self.has_child(), ContractError::ChildlessInternalNode);

// Calculate the balance factor to determine if rebalancing is needed.
let balance_factor = self.get_balance_factor(storage)?;
// Early return if the tree is already balanced.
Expand Down Expand Up @@ -655,15 +655,15 @@ impl TreeNode {
left.save(storage)?;
self.save(storage)?;

// Synchronize the range and value of the current node.
self.sync_range_and_value(storage)?;
left.sync_range_and_value(storage)?;

// If the left node has no parent, it becomes the new root.
if left.parent.is_none() {
TREE.save(storage, &(left.book_id, left.tick_id), &left.key)?;
}

// Synchronize the range and value of the current node.
self.sync_range_and_value(storage)?;
left.sync_range_and_value(storage)?;

// Update the parent's child pointers.
if is_left_child {
let mut parent = maybe_parent.clone().unwrap();
Expand All @@ -685,48 +685,49 @@ impl TreeNode {
/// has a greater height than the left subtree. It adjusts the pointers
/// accordingly to ensure the tree remains a valid binary search tree.
pub fn rotate_left(&mut self, storage: &mut dyn Storage) -> ContractResult<()> {
// Retrieve the parent node, if any, to determine the current node's relationship.
// Retrieve the parent node, if any.
let maybe_parent = self.get_parent(storage)?;

// Determine if the current node is a left or right child of its parent.
let is_left_child = maybe_parent
.clone()
.map_or(false, |p| p.left == Some(self.key));
let is_right_child = maybe_parent
.clone()
.map_or(false, |p| p.right == Some(self.key));

// Ensure the current node has a right child to perform the rotation.
// Ensure the current node has a right child to rotate.
let maybe_right = self.get_right(storage)?;
ensure!(maybe_right.is_some(), ContractError::InvalidNodeType);

// Perform the rotation by reassigning parent and child references.
// Perform the rotation.
let mut right = maybe_right.unwrap();
right.parent = self.parent;
self.parent = Some(right.key);
self.right = right.left;

// Update the parent reference of the new right child, if it exists.
if let Some(mut new_right) = self.get_left(storage)? {
// Update the parent of the new right child, if it exists.
if let Some(mut new_right) = self.get_right(storage)? {
new_right.parent = Some(self.key);
new_right.save(storage)?;
}

// Complete the rotation by setting the left child of the right node to the current node.
// Complete the rotation by setting the right child of the left node to the current node.
right.left = Some(self.key);

// Persist the changes to both nodes.
// Save the changes to both nodes.
right.save(storage)?;
self.save(storage)?;

// Synchronize the range and value of the current node.
self.sync_range_and_value(storage)?;
right.sync_range_and_value(storage)?;

// If the right node has no parent after the rotation, it becomes the new root.
// If the left node has no parent, it becomes the new root.
if right.parent.is_none() {
TREE.save(storage, &(right.book_id, right.tick_id), &right.key)?;
}

// Synchronize the range and value of the current node.
self.sync_range_and_value(storage)?;
right.sync_range_and_value(storage)?;

// Update the child references of the parent node to reflect the rotation.
// Update the parent's child pointers.
if is_left_child {
let mut parent = maybe_parent.clone().unwrap();
parent.left = Some(right.key);
Expand Down
1 change: 1 addition & 0 deletions contracts/sumtree-orderbook/src/sumtree/test/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod test_node;
pub mod test_tree;
85 changes: 76 additions & 9 deletions contracts/sumtree-orderbook/src/sumtree/test/test_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use cosmwasm_std::{testing::mock_dependencies, Decimal256, Deps, Storage, Uint25
use crate::{
sumtree::{
node::{generate_node_id, NodeType, TreeNode, NODES},
tree::{get_root_node, TREE},
tree::{get_prefix_sum, get_root_node, TREE},
},
ContractError,
};
Expand All @@ -18,7 +18,7 @@ struct TestNodeInsertCase {
}

// Asserts all values of internal nodes are as expected
fn assert_internal_values(
pub fn assert_internal_values(
test_name: &'static str,
deps: Deps,
internals: Vec<&TreeNode>,
Expand All @@ -37,7 +37,15 @@ fn assert_internal_values(
.map_or(Decimal256::zero(), |x| x.get_value()),
)
.unwrap();
assert_eq!(internal_node.get_value(), accumulated_value);
assert_eq!(
internal_node.get_value(),
accumulated_value,
"{} failed on internal node value, expected {} got {} for {}",
test_name,
accumulated_value,
internal_node.get_value(),
internal_node.key
);

let min = left_node
.clone()
Expand All @@ -59,8 +67,13 @@ fn assert_internal_values(
assert_eq!(internal_node.get_max_range(), max);

let balance_factor = right_node
.clone()
.map_or(0, |n| n.get_height(deps.storage).unwrap())
.abs_diff(left_node.map_or(0, |n| n.get_height(deps.storage).unwrap()));
.abs_diff(
left_node
.clone()
.map_or(0, |n| n.get_height(deps.storage).unwrap()),
);

assert_eq!(
internal_node.get_weight(),
Expand All @@ -78,6 +91,39 @@ fn assert_internal_values(
internal_node.key
);
}

if let Some(left) = left_node {
let parent_string = if let Some(parent) = left.parent {
parent.to_string()
} else {
"None".to_string()
};
assert_eq!(
left.parent,
Some(internal_node.key),
"{} - Child {} does not have correct parent: expected {}, received {}",
test_name,
left,
internal_node.key,
parent_string
);
}
if let Some(right) = right_node {
let parent_string = if let Some(parent) = right.parent {
parent.to_string()
} else {
"None".to_string()
};
assert_eq!(
right.parent,
Some(internal_node.key),
"{} - Child {} does not have correct parent: expected {}, received {}",
test_name,
right,
internal_node.key,
parent_string
);
}
}
}

Expand Down Expand Up @@ -1740,10 +1786,10 @@ fn test_rebalance() {
.with_parent(2),
// Left-Left-Left
TreeNode::new(book_id, tick_id, 6, NodeType::leaf_uint256(2u32, 1u32))
.with_parent(2),
.with_parent(4),
// Left-Left-Right
TreeNode::new(book_id, tick_id, 7, NodeType::leaf_uint256(3u32, 1u32))
.with_parent(2),
.with_parent(4),
// Left-Right
TreeNode::new(book_id, tick_id, 5, NodeType::leaf_uint256(4u32, 1u32))
.with_parent(4),
Expand Down Expand Up @@ -2060,11 +2106,13 @@ fn generate_nodes(
tick_id: i64,
quantity: u32,
) -> Vec<TreeNode> {
use rand::seq::SliceRandom;
use rand::thread_rng;
use rand::rngs::StdRng;
use rand::{seq::SliceRandom, SeedableRng};

let mut range: Vec<u32> = (0..quantity).collect();
range.shuffle(&mut thread_rng());
let seed = [0u8; 32]; // A fixed seed for deterministic randomness
let mut rng = StdRng::from_seed(seed);
range.shuffle(&mut rng);

let mut nodes = vec![];
for val in range {
Expand Down Expand Up @@ -2093,22 +2141,41 @@ fn test_node_insert_large_quantity() {
NodeType::internal_uint256(0u32, (u32::MAX, u32::MIN)),
);

TREE.save(deps.as_mut().storage, &(book_id, tick_id), &tree.key)
.unwrap();

let nodes = generate_nodes(deps.as_mut().storage, book_id, tick_id, 1000);

let target_etas = Decimal256::from_ratio(536u128, 1u128);
let mut expected_prefix_sum = Decimal256::zero();
let nodes_count = nodes.len();

// Insert nodes into tree
for mut node in nodes {
NODES
.save(deps.as_mut().storage, &(book_id, tick_id, node.key), &node)
.unwrap();
tree.insert(deps.as_mut().storage, &mut node).unwrap();
tree = get_root_node(deps.as_ref().storage, book_id, tick_id).unwrap();
// Track insertions that fall below our target ETAS
if node.get_min_range() < target_etas {
expected_prefix_sum = expected_prefix_sum.checked_add(Decimal256::one()).unwrap();
}
}

// Return tree in vector form from Depth First Search
let result = tree.traverse(deps.as_ref().storage).unwrap();

// Ensure all internal nodes are correctly summed and contain correct ranges
let internals: Vec<&TreeNode> = result.iter().filter(|x| x.is_internal()).collect();
let leaves: Vec<&TreeNode> = result.iter().filter(|x| !x.is_internal()).collect();
assert_internal_values("Large amount of nodes", deps.as_ref(), internals, true);

// Ensure prefix sum functions correctly
let root_node = get_root_node(deps.as_mut().storage, book_id, tick_id).unwrap();

let prefix_sum = get_prefix_sum(deps.as_mut().storage, root_node, target_etas).unwrap();
assert_eq!(expected_prefix_sum, prefix_sum);
}

const SPACING: u32 = 2u32;
Expand Down
Loading

0 comments on commit f89a55e

Please sign in to comment.