diff --git a/contracts/sumtree-orderbook/src/sumtree/node.rs b/contracts/sumtree-orderbook/src/sumtree/node.rs index 5ccccb5..0904747 100644 --- a/contracts/sumtree-orderbook/src/sumtree/node.rs +++ b/contracts/sumtree-orderbook/src/sumtree/node.rs @@ -9,7 +9,7 @@ use cosmwasm_schema::cw_serde; use cosmwasm_std::{ensure, Storage, Uint128}; use cw_storage_plus::Map; -use crate::{error::ContractResult, ContractError}; +use crate::{error::ContractResult, sumtree::tree::TREE, ContractError}; pub const NODES: Map<&(u64, i64, u64), TreeNode> = Map::new("nodes"); pub const NODE_ID_COUNTER: Map<&(u64, i64), u64> = Map::new("node_id"); @@ -40,6 +40,8 @@ pub enum NodeType { accumulator: Uint128, // Range from min ETAS to max ETAS + value of max ETAS range: (Uint128, Uint128), + // Amount of child nodes + weight: u64, }, } @@ -58,6 +60,7 @@ impl NodeType { Self::Internal { range: (range.0.into(), range.1.into()), accumulator: accumulator.into(), + weight: 0, } } } @@ -67,6 +70,7 @@ impl Default for NodeType { Self::Internal { accumulator: Uint128::zero(), range: (Uint128::MAX, Uint128::MIN), + weight: 0, } } } @@ -193,6 +197,23 @@ impl TreeNode { self.set_value(self.get_value().checked_add(value)?) } + pub fn get_weight(&self) -> u64 { + match self.node_type { + NodeType::Internal { weight, .. } => weight, + NodeType::Leaf { .. } => 0, + } + } + + pub fn set_weight(&mut self, new_weight: u64) -> ContractResult<()> { + match &mut self.node_type { + NodeType::Internal { weight, .. } => { + *weight = new_weight; + Ok(()) + } + NodeType::Leaf { .. } => Err(ContractError::InvalidNodeType), + } + } + /// Recalculates the range and accumulated value for a node and propagates it up the tree /// /// Must be an internal node @@ -283,6 +304,8 @@ impl TreeNode { self.set_max_range(new_node.get_max_range())?; } + self.set_weight(self.get_weight() + 1)?; + let maybe_left = self.get_left(storage)?; let maybe_right = self.get_right(storage)?; @@ -346,9 +369,12 @@ impl TreeNode { let right_is_leaf = maybe_right .clone() .map_or(false, |right| !right.is_internal()); + let is_higher_than_right_leaf = maybe_right.clone().map_or(false, |r| { + !r.is_internal() && new_node.get_min_range() >= r.get_max_range() + }); // Case 4 - if left_is_leaf { + if left_is_leaf && !is_higher_than_right_leaf { let mut left = maybe_left.unwrap(); let new_left = left.split(storage, new_node)?; self.left = Some(new_left); @@ -358,9 +384,6 @@ impl TreeNode { // Case 5: Reordering // TODO: Add edge case test for this - let is_higher_than_right_leaf = maybe_right.clone().map_or(false, |r| { - !r.is_internal() && new_node.get_min_range() >= r.get_max_range() - }); if is_higher_than_right_leaf && maybe_left.is_none() { self.left = self.right; self.right = Some(new_node.key); @@ -455,6 +478,64 @@ impl TreeNode { Ok(()) } + pub fn rotate_right(&mut self, storage: &mut dyn Storage) -> ContractResult<()> { + let maybe_left = self.get_left(storage)?; + ensure!(maybe_left.is_some(), ContractError::InvalidNodeType); + + let mut left = maybe_left.unwrap(); + + left.parent = self.parent; + self.parent = Some(left.key); + self.left = left.right; + + if let Some(mut right) = left.get_right(storage)? { + right.parent = Some(self.key); + right.save(storage)?; + } + + left.right = Some(self.key); + + left.save(storage)?; + self.save(storage)?; + + self.sync_range_and_value(storage)?; + + if left.parent.is_none() { + TREE.save(storage, &(left.book_id, left.tick_id), &left.key)?; + } + + Ok(()) + } + + pub fn rotate_left(&mut self, storage: &mut dyn Storage) -> ContractResult<()> { + let maybe_right = self.get_right(storage)?; + ensure!(maybe_right.is_some(), ContractError::InvalidNodeType); + + let mut right = maybe_right.unwrap(); + + right.parent = self.parent; + self.parent = Some(right.key); + self.right = right.left; + + if let Some(mut left) = right.get_left(storage)? { + left.parent = Some(self.key); + left.save(storage)?; + } + + right.left = Some(self.key); + + right.save(storage)?; + self.save(storage)?; + + self.sync_range_and_value(storage)?; + + if right.parent.is_none() { + TREE.save(storage, &(right.book_id, right.tick_id), &right.key)?; + } + + Ok(()) + } + #[cfg(test)] /// Depth first search traversal of tree pub fn traverse(&self, storage: &dyn Storage) -> ContractResult> { @@ -515,7 +596,9 @@ impl Display for NodeType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { NodeType::Leaf { value, etas } => write!(f, "{etas} {value}"), - NodeType::Internal { accumulator, range } => { + NodeType::Internal { + accumulator, range, .. + } => { write!(f, "{} {}-{}", accumulator, range.0, range.1) } } diff --git a/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs b/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs index a5a09c0..cf9be77 100644 --- a/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs +++ b/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs @@ -1,6 +1,9 @@ use cosmwasm_std::{testing::mock_dependencies, Uint128}; -use crate::sumtree::node::{generate_node_id, NodeType, TreeNode, NODES}; +use crate::sumtree::{ + node::{generate_node_id, NodeType, TreeNode, NODES}, + tree::get_root_node, +}; struct TestNodeInsertCase { name: &'static str, @@ -600,3 +603,162 @@ fn test_node_deletion_valid() { } } } + +enum BalanceDirection { + Left, + Right, +} + +struct TreeRebalancingTestCase { + name: &'static str, + nodes: Vec, + // Depth first search ordering of node IDs (Could be improved?) + expected: Vec, + // Whether to print the tree + print: bool, + direction: BalanceDirection, +} + +#[test] +fn test_tree_rebalancing() { + let book_id = 1; + let tick_id = 1; + let test_cases: Vec = vec![ + TreeRebalancingTestCase { + name: "Left heavy tree", + nodes: vec![ + NodeType::leaf(1u32, 1u32), + NodeType::leaf(9u32, 1u32), + NodeType::leaf(6u32, 1u32), + NodeType::leaf(3u32, 1u32), + NodeType::leaf(2u32, 1u32), + NodeType::leaf(4u32, 1u32), + ], + expected: vec![], + print: true, + direction: BalanceDirection::Right, + }, + TreeRebalancingTestCase { + name: "Right heavy tree", + nodes: vec![ + NodeType::leaf(1u32, 1u32), + NodeType::leaf(5u32, 1u32), + NodeType::leaf(9u32, 1u32), + NodeType::leaf(6u32, 1u32), + // NodeType::leaf(3u32, 1u32), + // NodeType::leaf(2u32, 1u32), + // NodeType::leaf(4u32, 1u32), + ], + expected: vec![], + print: true, + direction: BalanceDirection::Left, + }, + ]; + + for test in test_cases { + let mut deps = mock_dependencies(); + let mut tree = TreeNode::new( + book_id, + tick_id, + generate_node_id(deps.as_mut().storage, book_id, tick_id).unwrap(), + NodeType::internal(Uint128::zero(), (u32::MAX, u32::MIN)), + ); + + for node in test.nodes { + let mut tree_node = TreeNode::new( + book_id, + tick_id, + generate_node_id(deps.as_mut().storage, book_id, tick_id).unwrap(), + node, + ); + NODES + .save( + deps.as_mut().storage, + &(book_id, tick_id, tree_node.key), + &tree_node, + ) + .unwrap(); + tree.insert(deps.as_mut().storage, &mut tree_node).unwrap(); + } + + if test.print { + println!("Pre-Rotation Tree: {}", test.name); + println!("--------------------------"); + let nodes = tree.traverse_bfs(deps.as_ref().storage).unwrap(); + for (idx, row) in nodes.iter().enumerate() { + print_tree_row(row.clone(), idx == 0, (nodes.len() - idx - 1) as u32); + } + println!(); + } + + match test.direction { + BalanceDirection::Left => tree.rotate_left(deps.as_mut().storage).unwrap(), + BalanceDirection::Right => tree.rotate_right(deps.as_mut().storage).unwrap(), + } + + let tree = get_root_node(deps.as_ref().storage, book_id, tick_id).unwrap(); + + if test.print { + println!("Post-Rotation Tree: {}", test.name); + println!("--------------------------"); + let nodes = tree.traverse_bfs(deps.as_ref().storage).unwrap(); + for (idx, row) in nodes.iter().enumerate() { + print_tree_row(row.clone(), idx == 0, (nodes.len() - idx - 1) as u32); + } + println!(); + } + + // let result = tree.traverse(deps.as_ref().storage).unwrap(); + + // assert_eq!( + // result, + // test.expected + // .iter() + // .map(|key| NODES + // .load(deps.as_ref().storage, &(book_id, tick_id, *key)) + // .unwrap()) + // .collect::>() + // ); + + // Uncomment post rebalancing implementation + // let internals: Vec<&TreeNode> = result.iter().filter(|x| x.is_internal()).collect(); + // for internal_node in internals { + // let left_node = internal_node.get_left(deps.as_ref().storage).unwrap(); + // let right_node = internal_node.get_right(deps.as_ref().storage).unwrap(); + + // let accumulated_value = left_node + // .clone() + // .map(|x| x.get_value()) + // .unwrap_or_default() + // .checked_add( + // right_node + // .clone() + // .map(|x| x.get_value()) + // .unwrap_or_default(), + // ) + // .unwrap(); + // assert_eq!(internal_node.get_value(), accumulated_value); + + // let min = left_node + // .clone() + // .map(|n| n.get_min_range()) + // .unwrap_or(Uint128::MAX) + // .min( + // right_node + // .clone() + // .map(|n| n.get_min_range()) + // .unwrap_or(Uint128::MAX), + // ); + // let max = left_node + // .map(|n| n.get_max_range()) + // .unwrap_or(Uint128::MIN) + // .max( + // right_node + // .map(|n| n.get_max_range()) + // .unwrap_or(Uint128::MIN), + // ); + // assert_eq!(internal_node.get_min_range(), min); + // assert_eq!(internal_node.get_max_range(), max); + // } + } +} diff --git a/contracts/sumtree-orderbook/src/sumtree/tree.rs b/contracts/sumtree-orderbook/src/sumtree/tree.rs index 3d81696..fcc8d7f 100644 --- a/contracts/sumtree-orderbook/src/sumtree/tree.rs +++ b/contracts/sumtree-orderbook/src/sumtree/tree.rs @@ -1,7 +1,18 @@ +use cosmwasm_std::Storage; use cw_storage_plus::Map; -use super::node::TreeNode; +use crate::error::ContractResult; + +use super::node::{TreeNode, NODES}; -// TODO: REMOVE #[allow(dead_code)] -pub const TREE: Map<&(u64, i64), TreeNode> = Map::new("tree"); +pub const TREE: Map<&(u64, i64), u64> = Map::new("tree"); + +pub fn get_root_node( + storage: &dyn Storage, + book_id: u64, + tick_id: i64, +) -> ContractResult { + let root_id = TREE.load(storage, &(book_id, tick_id))?; + Ok(NODES.load(storage, &(book_id, tick_id, root_id))?) +}