diff --git a/contracts/sumtree-orderbook/src/sumtree/node.rs b/contracts/sumtree-orderbook/src/sumtree/node.rs index b497389..8f65a5d 100644 --- a/contracts/sumtree-orderbook/src/sumtree/node.rs +++ b/contracts/sumtree-orderbook/src/sumtree/node.rs @@ -115,6 +115,18 @@ impl TreeNode { } } + pub fn get_parent(&self, storage: &dyn Storage) -> ContractResult> { + if let Some(parent) = self.parent { + Ok(NODES.may_load(storage, &(self.book_id, self.tick_id, parent))?) + } else { + Ok(None) + } + } + + pub fn has_child(&self) -> bool { + self.left.is_some() || self.right.is_some() + } + pub fn save(&self, storage: &mut dyn Storage) -> ContractResult<()> { Ok(NODES.save(storage, &(self.book_id, self.tick_id, self.key), self)?) } @@ -161,19 +173,75 @@ impl TreeNode { } } - /// Adds a given value to an internal node's accumulator - /// - /// Errors if given node is not internal - pub fn add_value(&mut self, value: Uint128) -> ContractResult<()> { + pub fn set_value(&mut self, value: Uint128) -> ContractResult<()> { match &mut self.node_type { NodeType::Internal { accumulator, .. } => { - *accumulator = accumulator.checked_add(value)?; + *accumulator = value; Ok(()) } NodeType::Leaf { .. } => Err(ContractError::InvalidNodeType), } } + /// Adds a given value to an internal node's accumulator + /// + /// Errors if given node is not internal + pub fn add_value(&mut self, value: Uint128) -> ContractResult<()> { + self.set_value(self.get_value().checked_add(value)?) + } + + // TODO: This can likely be optimized + /// Recalculates the range and accumulated value for a node and propagates it up the tree + pub fn recalculate_values(&mut self, storage: &mut dyn Storage) -> ContractResult<()> { + let maybe_left = self.get_left(storage)?; + let maybe_right = self.get_right(storage)?; + + // Calculate min from remaining children + // Attempt to get min value or default to Uint128::MAX for both nodes + // Take min from both returned values + let min = maybe_left + .clone() + .map(|n| n.get_min_range()) + .unwrap_or(Uint128::MAX) + .min( + maybe_right + .clone() + .map(|n| n.get_min_range()) + .unwrap_or(Uint128::MAX), + ); + // Calculate max from remaining children + // Attempt to get max value or default to Uint128::MIN for both nodes + // Take max from both returned values + let max = maybe_left + .clone() + .map(|n| n.get_max_range()) + .unwrap_or(Uint128::MIN) + .max( + maybe_right + .clone() + .map(|n| n.get_max_range()) + .unwrap_or(Uint128::MIN), + ); + + self.set_min_range(min)?; + self.set_max_range(max)?; + + let value = maybe_left + .map(|n| n.get_value()) + .unwrap_or_default() + .checked_add(maybe_right.map(|n| n.get_value()).unwrap_or_default())?; + self.set_value(value)?; + + // Must save before propagating as parent will read this node + self.save(storage)?; + + if let Some(mut parent) = self.get_parent(storage)? { + parent.recalculate_values(storage)?; + } + + Ok(()) + } + /// Gets the value for a given node. /// /// For `Leaf` nodes this is the `value`. @@ -322,10 +390,11 @@ impl TreeNode { ); // Save new key references - self.parent = Some(id); - new_node.parent = Some(id); + new_parent.parent = self.parent; new_parent.left = Some(new_left); new_parent.right = Some(new_right); + self.parent = Some(id); + new_node.parent = Some(id); new_parent.save(storage)?; self.save(storage)?; @@ -334,6 +403,30 @@ impl TreeNode { Ok(id) } + pub fn delete(&self, storage: &mut dyn Storage) -> ContractResult<()> { + let maybe_parent = self.get_parent(storage)?; + if let Some(mut parent) = maybe_parent { + // Remove node reference from parent + if parent.left == Some(self.key) { + parent.left = None; + } else if parent.right == Some(self.key) { + parent.right = None; + } + + if !parent.has_child() { + // Remove no-leaf parents + parent.delete(storage)?; + } else { + // Update parents values after removing node + parent.recalculate_values(storage)?; + } + } + + NODES.remove(storage, &(self.book_id, self.tick_id, self.key)); + + Ok(()) + } + #[cfg(test)] /// Depth first search traversal of tree pub fn traverse(&self, storage: &dyn Storage) -> ContractResult> { diff --git a/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs b/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs index 3e02f0c..e7ed41c 100644 --- a/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs +++ b/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs @@ -189,3 +189,174 @@ fn test_node_insert_valid() { } } } + +struct NodeDeletionTestCase { + name: &'static str, + nodes: Vec, + delete: Vec, + // Depth first search ordering of node IDs (Could be improved?) + expected: Vec, + // Whether to print the tree + print: bool, +} + +#[test] +fn test_node_deletion_valid() { + let book_id = 1; + let tick_id = 1; + let test_cases: Vec = vec![ + NodeDeletionTestCase { + name: "Remove only node", + nodes: vec![NodeType::leaf(1u32, 10u32)], + delete: vec![2], + expected: vec![], + print: true, + }, + NodeDeletionTestCase { + name: "Remove one of two nodes", + nodes: vec![NodeType::leaf(1u32, 10u32), NodeType::leaf(11u32, 5u32)], + delete: vec![2], + expected: vec![1, 3], + print: true, + }, + NodeDeletionTestCase { + name: "Remove nested node", + nodes: vec![ + NodeType::leaf(1u32, 10u32), + NodeType::leaf(21u32, 5u32), + NodeType::leaf(11u32, 10u32), + ], + delete: vec![2], + expected: vec![1, 5, 4, 3], + print: true, + }, + NodeDeletionTestCase { + name: "Remove both children of internal", + nodes: vec![ + NodeType::leaf(1u32, 10u32), + NodeType::leaf(21u32, 5u32), + NodeType::leaf(11u32, 10u32), + ], + delete: vec![2, 4], + expected: vec![1, 3], + print: true, + }, + ]; + + for test in test_cases { + let mut deps = mock_dependencies(); + let mut tree = TreeNode::new( + book_id, + tick_id, + generate_node_id(deps.as_mut().storage, book_id, tick_id).unwrap(), + NodeType::internal(Uint128::zero(), (u32::MAX, u32::MIN)), + ); + + for node in test.nodes { + let mut tree_node = TreeNode::new( + book_id, + tick_id, + generate_node_id(deps.as_mut().storage, book_id, tick_id).unwrap(), + node, + ); + NODES + .save( + deps.as_mut().storage, + &(book_id, tick_id, tree_node.key), + &tree_node, + ) + .unwrap(); + tree.insert(deps.as_mut().storage, &mut tree_node).unwrap(); + } + + if test.print { + println!("Pre-Deletion Tree: {}", test.name); + println!("--------------------------"); + print_tree(deps.as_ref().storage, &tree, 0, true); + println!(); + } + + for key in test.delete.clone() { + let node = NODES + .load(deps.as_ref().storage, &(book_id, tick_id, key)) + .unwrap(); + node.delete(deps.as_mut().storage).unwrap(); + } + + if test.expected.is_empty() { + let maybe_parent = tree.get_parent(deps.as_ref().storage).unwrap(); + assert!(maybe_parent.is_none(), "Parent node should not exist"); + continue; + } + + let tree = NODES + .load(deps.as_ref().storage, &(book_id, tick_id, tree.key)) + .unwrap(); + + if test.print { + println!("Post-Deletion Tree: {}", test.name); + println!("--------------------------"); + print_tree(deps.as_ref().storage, &tree, 0, true); + println!(); + } + + let result = tree.traverse(deps.as_ref().storage).unwrap(); + + assert_eq!( + result, + test.expected + .iter() + .map(|key| NODES + .load(deps.as_ref().storage, &(book_id, tick_id, *key)) + .unwrap()) + .collect::>() + ); + + for key in test.delete { + let maybe_node = NODES + .may_load(deps.as_ref().storage, &(book_id, tick_id, key)) + .unwrap(); + assert!(maybe_node.is_none(), "Node {key} was not deleted"); + } + + let internals: Vec<&TreeNode> = result.iter().filter(|x| x.is_internal()).collect(); + for internal_node in internals { + let left_node = internal_node.get_left(deps.as_ref().storage).unwrap(); + let right_node = internal_node.get_right(deps.as_ref().storage).unwrap(); + + let accumulated_value = left_node + .clone() + .map(|x| x.get_value()) + .unwrap_or_default() + .checked_add( + right_node + .clone() + .map(|x| x.get_value()) + .unwrap_or_default(), + ) + .unwrap(); + assert_eq!(internal_node.get_value(), accumulated_value); + + let min = left_node + .clone() + .map(|n| n.get_min_range()) + .unwrap_or(Uint128::MAX) + .min( + right_node + .clone() + .map(|n| n.get_min_range()) + .unwrap_or(Uint128::MAX), + ); + let max = left_node + .map(|n| n.get_max_range()) + .unwrap_or(Uint128::MIN) + .max( + right_node + .map(|n| n.get_max_range()) + .unwrap_or(Uint128::MIN), + ); + assert_eq!(internal_node.get_min_range(), min); + assert_eq!(internal_node.get_max_range(), max); + } + } +}