diff --git a/contracts/sumtree-orderbook/src/error.rs b/contracts/sumtree-orderbook/src/error.rs index a2ac754..752cd8c 100644 --- a/contracts/sumtree-orderbook/src/error.rs +++ b/contracts/sumtree-orderbook/src/error.rs @@ -77,6 +77,9 @@ pub enum ContractError { #[error("Invalid Node Type")] InvalidNodeType, + + #[error("Childless Internal Node")] + ChildlessInternalNode, } pub type ContractResult = Result; diff --git a/contracts/sumtree-orderbook/src/sumtree/node.rs b/contracts/sumtree-orderbook/src/sumtree/node.rs index c7df4d0..5ccccb5 100644 --- a/contracts/sumtree-orderbook/src/sumtree/node.rs +++ b/contracts/sumtree-orderbook/src/sumtree/node.rs @@ -118,6 +118,18 @@ impl TreeNode { } } + pub fn get_parent(&self, storage: &dyn Storage) -> ContractResult> { + if let Some(parent) = self.parent { + Ok(NODES.may_load(storage, &(self.book_id, self.tick_id, parent))?) + } else { + Ok(None) + } + } + + pub fn has_child(&self) -> bool { + self.left.is_some() || self.right.is_some() + } + pub fn save(&self, storage: &mut dyn Storage) -> ContractResult<()> { Ok(NODES.save(storage, &(self.book_id, self.tick_id, self.key), self)?) } @@ -164,19 +176,73 @@ impl TreeNode { } } - /// Adds a given value to an internal node's accumulator - /// - /// Errors if given node is not internal - pub fn add_value(&mut self, value: Uint128) -> ContractResult<()> { + pub fn set_value(&mut self, value: Uint128) -> ContractResult<()> { match &mut self.node_type { NodeType::Internal { accumulator, .. } => { - *accumulator = accumulator.checked_add(value)?; + *accumulator = value; Ok(()) } NodeType::Leaf { .. } => Err(ContractError::InvalidNodeType), } } + /// Adds a given value to an internal node's accumulator + /// + /// Errors if given node is not internal + pub fn add_value(&mut self, value: Uint128) -> ContractResult<()> { + self.set_value(self.get_value().checked_add(value)?) + } + + /// Recalculates the range and accumulated value for a node and propagates it up the tree + /// + /// Must be an internal node + pub fn sync_range_and_value(&mut self, storage: &mut dyn Storage) -> ContractResult<()> { + ensure!(self.is_internal(), ContractError::InvalidNodeType); + let maybe_left = self.get_left(storage)?; + let maybe_right = self.get_right(storage)?; + + let left_exists = maybe_left.is_some(); + let right_exists = maybe_right.is_some(); + + if !self.has_child() { + return Err(ContractError::ChildlessInternalNode); + } + + let (min, max) = if left_exists && !right_exists { + let left = maybe_left.clone().unwrap(); + (left.get_min_range(), left.get_max_range()) + } else if right_exists && !left_exists { + let right = maybe_right.clone().unwrap(); + (right.get_min_range(), right.get_max_range()) + } else { + let left = maybe_left.clone().unwrap(); + let right = maybe_right.clone().unwrap(); + + ( + left.get_min_range().min(right.get_min_range()), + left.get_max_range().max(right.get_max_range()), + ) + }; + + self.set_min_range(min)?; + self.set_max_range(max)?; + + let value = maybe_left + .map(|n| n.get_value()) + .unwrap_or_default() + .checked_add(maybe_right.map(|n| n.get_value()).unwrap_or_default())?; + self.set_value(value)?; + + // Must save before propagating as parent will read this node + self.save(storage)?; + + if let Some(mut parent) = self.get_parent(storage)? { + parent.sync_range_and_value(storage)?; + } + + Ok(()) + } + /// Gets the value for a given node. /// /// For `Leaf` nodes this is the `value`. @@ -284,7 +350,7 @@ impl TreeNode { // Case 4 if left_is_leaf { let mut left = maybe_left.unwrap(); - let new_left = left.split(storage, new_node, self.key)?; + let new_left = left.split(storage, new_node)?; self.left = Some(new_left); self.save(storage)?; return Ok(()); @@ -307,7 +373,7 @@ impl TreeNode { // Case 5 if !is_in_left_range && right_is_leaf { let mut right = maybe_right.unwrap(); - let new_right = right.split(storage, new_node, self.key)?; + let new_right = right.split(storage, new_node)?; self.right = Some(new_right); self.save(storage)?; return Ok(()); @@ -324,7 +390,6 @@ impl TreeNode { &mut self, storage: &mut dyn Storage, new_node: &mut TreeNode, - parent_id: u64, ) -> ContractResult { ensure!(!self.is_internal(), ContractError::InvalidNodeType); let id = generate_node_id(storage, self.book_id, self.tick_id)?; @@ -350,11 +415,11 @@ impl TreeNode { ); // Save new key references - self.parent = Some(id); - new_node.parent = Some(id); + new_parent.parent = self.parent; new_parent.left = Some(new_left); new_parent.right = Some(new_right); - new_parent.parent = Some(parent_id); + self.parent = Some(id); + new_node.parent = Some(id); new_parent.save(storage)?; self.save(storage)?; @@ -363,6 +428,33 @@ impl TreeNode { Ok(id) } + /// Deletes a given node from the tree and propagates value changes up through its parent nodes. + /// + /// If the parent node has no children after removal it is also deleted recursively, to prune empty branches. + pub fn delete(&self, storage: &mut dyn Storage) -> ContractResult<()> { + let maybe_parent = self.get_parent(storage)?; + if let Some(mut parent) = maybe_parent { + // Remove node reference from parent + if parent.left == Some(self.key) { + parent.left = None; + } else if parent.right == Some(self.key) { + parent.right = None; + } + + if !parent.has_child() { + // Remove no-children parents + parent.delete(storage)?; + } else { + // Update parents values after removing node + parent.sync_range_and_value(storage)?; + } + } + + NODES.remove(storage, &(self.book_id, self.tick_id, self.key)); + + Ok(()) + } + #[cfg(test)] /// Depth first search traversal of tree pub fn traverse(&self, storage: &dyn Storage) -> ContractResult> { diff --git a/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs b/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs index 0f3464e..a5a09c0 100644 --- a/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs +++ b/contracts/sumtree-orderbook/src/sumtree/test/test_node.rs @@ -351,3 +351,252 @@ pub fn print_tree_row(row: Vec<(Option, Option)>, top: bool, } println!("{line}") } + +struct NodeDeletionTestCase { + name: &'static str, + nodes: Vec, + delete: Vec, + // Depth first search ordering of node IDs (Could be improved?) + expected: Vec, + // Whether to print the tree + print: bool, +} + +#[test] +fn test_node_deletion_valid() { + let book_id = 1; + let tick_id = 1; + let test_cases: Vec = vec![ + // Pre + // --- + // 1: 10 1-11 + // ┌──────── + // ->2: 1 10 + // + // Post + // ---- + // No tree + NodeDeletionTestCase { + name: "Remove only node", + nodes: vec![NodeType::leaf(1u32, 10u32)], + delete: vec![2], + expected: vec![], + print: true, + }, + // Pre + // --- + // 1: 15 1-16 + // ┌────────────────┐ + // ->2: 1 10 3: 11 5 + // + // Post + // ---- + // 1: 5 11-16 + // ────────┐ + // 3: 11 5 + NodeDeletionTestCase { + name: "Remove one of two nodes", + nodes: vec![NodeType::leaf(1u32, 10u32), NodeType::leaf(11u32, 5u32)], + delete: vec![2], + expected: vec![1, 3], + print: true, + }, + // Pre + // --- + // 1: 25 1-26 + // ┌────────────────────────────────┐ + // 5: 20 1-21 3: 21 5 + // ┌────────────────┐ + // ->2: 1 10 4: 11 10 + // + // Post + // ---- + // 1: 15 11-26 + // ┌────────────────────────────────┐ + // 5: 10 11-21 3: 21 5 + // ────────┐ + // 4: 11 10 + NodeDeletionTestCase { + name: "Remove nested node", + nodes: vec![ + NodeType::leaf(1u32, 10u32), + NodeType::leaf(21u32, 5u32), + NodeType::leaf(11u32, 10u32), + ], + delete: vec![2], + expected: vec![1, 5, 4, 3], + print: true, + }, + // Pre + // --- + // 1: 25 1-26 + // ┌────────────────────────────────┐ + // 5: 20 1-21 3: 21 5 + // ┌────────────────┐ + // ->2: 1 10 ->4: 11 10 + // + // Post + // ---- + // 1: 5 21-26 + // ────────┐ + // 3: 21 5 + NodeDeletionTestCase { + name: "Remove both children of internal", + nodes: vec![ + NodeType::leaf(1u32, 10u32), + NodeType::leaf(21u32, 5u32), + NodeType::leaf(11u32, 10u32), + ], + delete: vec![2, 4], + expected: vec![1, 3], + print: true, + }, + // Pre + // --- + // 1: 25 1-26 + // ┌────────────────────────────────┐ + // ->5: 20 1-21 3: 21 5 + // ┌────────────────┐ + // 2: 1 10 4: 11 10 + // + // Post + // ---- + // 1: 5 21-26 + // ────────┐ + // 3: 21 5 + NodeDeletionTestCase { + name: "Remove parent node", + nodes: vec![ + NodeType::leaf(1u32, 10u32), + NodeType::leaf(21u32, 5u32), + NodeType::leaf(11u32, 10u32), + ], + delete: vec![5], + expected: vec![1, 3], + print: true, + }, + ]; + + for test in test_cases { + let mut deps = mock_dependencies(); + let mut tree = TreeNode::new( + book_id, + tick_id, + generate_node_id(deps.as_mut().storage, book_id, tick_id).unwrap(), + NodeType::internal(Uint128::zero(), (u32::MAX, u32::MIN)), + ); + + for node in test.nodes { + let mut tree_node = TreeNode::new( + book_id, + tick_id, + generate_node_id(deps.as_mut().storage, book_id, tick_id).unwrap(), + node, + ); + NODES + .save( + deps.as_mut().storage, + &(book_id, tick_id, tree_node.key), + &tree_node, + ) + .unwrap(); + tree.insert(deps.as_mut().storage, &mut tree_node).unwrap(); + } + + if test.print { + println!("Pre-Deletion Tree: {}", test.name); + println!("--------------------------"); + let nodes = tree.traverse_bfs(deps.as_ref().storage).unwrap(); + for (idx, row) in nodes.iter().enumerate() { + print_tree_row(row.clone(), idx == 0, (nodes.len() - idx - 1) as u32); + } + println!(); + } + + for key in test.delete.clone() { + let node = NODES + .load(deps.as_ref().storage, &(book_id, tick_id, key)) + .unwrap(); + node.delete(deps.as_mut().storage).unwrap(); + } + + if test.expected.is_empty() { + let maybe_parent = tree.get_parent(deps.as_ref().storage).unwrap(); + assert!(maybe_parent.is_none(), "Parent node should not exist"); + continue; + } + + let tree = NODES + .load(deps.as_ref().storage, &(book_id, tick_id, tree.key)) + .unwrap(); + + if test.print { + println!("Post-Deletion Tree: {}", test.name); + println!("--------------------------"); + let nodes = tree.traverse_bfs(deps.as_ref().storage).unwrap(); + for (idx, row) in nodes.iter().enumerate() { + print_tree_row(row.clone(), idx == 0, (nodes.len() - idx - 1) as u32); + } + println!(); + } + + let result = tree.traverse(deps.as_ref().storage).unwrap(); + + assert_eq!( + result, + test.expected + .iter() + .map(|key| NODES + .load(deps.as_ref().storage, &(book_id, tick_id, *key)) + .unwrap()) + .collect::>() + ); + + for key in test.delete { + let maybe_node = NODES + .may_load(deps.as_ref().storage, &(book_id, tick_id, key)) + .unwrap(); + assert!(maybe_node.is_none(), "Node {key} was not deleted"); + } + + let internals: Vec<&TreeNode> = result.iter().filter(|x| x.is_internal()).collect(); + for internal_node in internals { + let left_node = internal_node.get_left(deps.as_ref().storage).unwrap(); + let right_node = internal_node.get_right(deps.as_ref().storage).unwrap(); + + let accumulated_value = left_node + .clone() + .map(|x| x.get_value()) + .unwrap_or_default() + .checked_add( + right_node + .clone() + .map(|x| x.get_value()) + .unwrap_or_default(), + ) + .unwrap(); + assert_eq!(internal_node.get_value(), accumulated_value); + + let min = left_node + .clone() + .map(|n| n.get_min_range()) + .unwrap_or(Uint128::MAX) + .min( + right_node + .clone() + .map(|n| n.get_min_range()) + .unwrap_or(Uint128::MAX), + ); + let max = left_node + .map(|n| n.get_max_range()) + .unwrap_or(Uint128::MIN) + .max( + right_node + .map(|n| n.get_max_range()) + .unwrap_or(Uint128::MIN), + ); + assert_eq!(internal_node.get_min_range(), min); + assert_eq!(internal_node.get_max_range(), max); + } + } +}