From d9070ea20c251e8962a24552e053faceff903cf4 Mon Sep 17 00:00:00 2001 From: Connor Barr Date: Fri, 22 Mar 2024 18:13:52 +0000 Subject: [PATCH] Better synced rebalance method with spec --- .../sumtree-orderbook/src/sumtree/node.rs | 76 ++++++++++++------- .../src/sumtree/test/test_node.rs | 9 ++- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/contracts/sumtree-orderbook/src/sumtree/node.rs b/contracts/sumtree-orderbook/src/sumtree/node.rs index 4609eaa..227d90f 100644 --- a/contracts/sumtree-orderbook/src/sumtree/node.rs +++ b/contracts/sumtree-orderbook/src/sumtree/node.rs @@ -196,13 +196,6 @@ impl TreeNode { } } - /// Adds a given value to an internal node's accumulator - /// - /// Errors if given node is not internal - pub fn add_value(&mut self, value: Uint128) -> ContractResult<()> { - self.set_value(self.get_value().checked_add(value)?) - } - pub fn get_weight(&self) -> u64 { match self.node_type { NodeType::Internal { weight, .. } => weight, @@ -210,6 +203,13 @@ impl TreeNode { } } + /// Adds a given value to an internal node's accumulator + /// + /// Errors if given node is not internal + pub fn add_value(&mut self, value: Uint128) -> ContractResult<()> { + self.set_value(self.get_value().checked_add(value)?) + } + pub fn set_weight(&mut self, new_weight: u64) -> ContractResult<()> { match &mut self.node_type { NodeType::Internal { weight, .. } => { @@ -220,6 +220,18 @@ impl TreeNode { } } + /// Gets the value for a given node. + /// + /// For `Leaf` nodes this is the `value`. + /// + /// For `Internal` nodes this is the `accumulator`. + pub fn get_value(&self) -> Uint128 { + match self.node_type { + NodeType::Leaf { value, .. } => value, + NodeType::Internal { accumulator, .. } => accumulator, + } + } + /// Recalculates the range and accumulated value for a node and propagates it up the tree /// /// Must be an internal node @@ -282,18 +294,6 @@ impl TreeNode { Ok(()) } - /// Gets the value for a given node. - /// - /// For `Leaf` nodes this is the `value`. - /// - /// For `Internal` nodes this is the `accumulator`. - pub fn get_value(&self) -> Uint128 { - match self.node_type { - NodeType::Leaf { value, .. } => value, - NodeType::Internal { accumulator, .. } => accumulator, - } - } - /// Inserts a given node in to the tree /// /// If the node is internal an error is returned. @@ -371,6 +371,8 @@ impl TreeNode { new_node.parent = Some(self.key); new_node.save(storage)?; self.save(storage)?; + #[cfg(test)] + println!("Inserted {} left on {}", new_node, self); self.rebalance(storage)?; return Ok(()); } @@ -385,6 +387,8 @@ impl TreeNode { new_node.parent = Some(self.key); new_node.save(storage)?; self.save(storage)?; + #[cfg(test)] + println!("Inserted {} right on {}", new_node, self); self.rebalance(storage)?; return Ok(()); } @@ -395,6 +399,8 @@ impl TreeNode { new_node.parent = Some(self.key); new_node.save(storage)?; self.save(storage)?; + #[cfg(test)] + println!("Inserted {} right on {}", new_node, self); self.rebalance(storage)?; return Ok(()); } @@ -413,6 +419,11 @@ impl TreeNode { let new_left = left.split(storage, new_node)?; self.left = Some(new_left); self.save(storage)?; + #[cfg(test)] + println!( + "Split {} left on {} generating {}", + new_node, self, new_left + ); self.rebalance(storage)?; return Ok(()); } @@ -425,6 +436,8 @@ impl TreeNode { new_node.parent = Some(self.key); new_node.save(storage)?; self.save(storage)?; + #[cfg(test)] + println!("Inserted {} right on {} after reordering", new_node, self); self.rebalance(storage)?; return Ok(()); } @@ -435,6 +448,11 @@ impl TreeNode { let new_right = right.split(storage, new_node)?; self.right = Some(new_right); self.save(storage)?; + #[cfg(test)] + println!( + "Split {} right on {} generating {}", + new_node, self, new_right + ); self.rebalance(storage)?; return Ok(()); } @@ -521,7 +539,7 @@ impl TreeNode { pub fn get_balance_factor(&self, storage: &dyn Storage) -> ContractResult { let left_weight = self.get_left(storage)?.map_or(0, |n| n.get_weight()); let right_weight = self.get_right(storage)?.map_or(0, |n| n.get_weight()); - Ok(left_weight as i32 - right_weight as i32) + Ok(right_weight as i32 - left_weight as i32) } /// Rebalances the tree starting from the current node. @@ -550,8 +568,8 @@ impl TreeNode { let maybe_right = self.get_right(storage)?; // Determine the direction of imbalance. - let is_right_leaning = balance_factor < 0; - let is_left_leaning = balance_factor > 0; + let is_right_leaning = balance_factor > 0; + let is_left_leaning = balance_factor < 0; // Calculate balance factors for child nodes to determine rotation type. let right_balance_factor = maybe_right @@ -562,14 +580,14 @@ impl TreeNode { .map_or(0, |n| n.get_balance_factor(storage).unwrap_or(0)); // Perform rotations based on the type of imbalance detected. - // Case 1: Left-Left (Right rotation needed) - if is_left_leaning && left_balance_factor >= 0 { - self.rotate_right(storage)?; - } - // Case 2: Right-Right (Left rotation needed) - else if is_right_leaning && right_balance_factor >= 0 { + // Case 1: Right-Right (Right rotation needed) + if is_right_leaning && right_balance_factor >= 0 { self.rotate_left(storage)?; } + // Case 2: Left Left (Right rotation needed) + else if is_left_leaning && left_balance_factor <= 0 { + self.rotate_right(storage)?; + } // Case 3: Right-Left (Right rotation on right child followed by Left rotation on self) else if is_right_leaning && right_balance_factor < 0 { maybe_right.unwrap().rotate_right(storage)?; @@ -577,7 +595,7 @@ impl TreeNode { self.rotate_left(storage)?; } // Case 4: Left-Right (Left rotation on left child followed by Right rotation on self) - else if is_left_leaning && left_balance_factor < 0 { + else if is_left_leaning && left_balance_factor > 0 { maybe_left.unwrap().rotate_left(storage)?; self.sync(storage)?; self.rotate_right(storage)?; diff --git a/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs b/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs index a77359a..ee91e38 100644 --- a/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs +++ b/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs @@ -494,10 +494,10 @@ fn test_tree_rebalancing() { NodeType::leaf(6u32, 1u32), NodeType::leaf(3u32, 1u32), NodeType::leaf(5u32, 1u32), + NodeType::leaf(20u32, 1u32), + NodeType::leaf(13u32, 1u32), NodeType::leaf(4u32, 1u32), NodeType::leaf(7u32, 1u32), - NodeType::leaf(12u32, 1u32), - // NodeType::leaf(20u32, 1u32), ], print: true, }, @@ -511,8 +511,9 @@ fn test_tree_rebalancing() { NodeType::leaf(3u32, 1u32), NodeType::leaf(2u32, 1u32), NodeType::leaf(4u32, 1u32), + NodeType::leaf(12u32, 1u32), ], - print: false, + print: true, }, ]; @@ -538,7 +539,7 @@ fn test_tree_rebalancing() { generate_node_id(deps.as_mut().storage, book_id, tick_id).unwrap(), node, ); - println!("Inserting: {} into {}", tree_node, tree); + // println!("Inserting: {} into {}", tree_node, tree); tree.insert(deps.as_mut().storage, &mut tree_node).unwrap(); tree = get_root_node(deps.as_ref().storage, book_id, tick_id).unwrap();