Skip to content

Commit

Permalink
Rebalancing initial
Browse files Browse the repository at this point in the history
  • Loading branch information
crnbarr93 committed Mar 20, 2024
1 parent 5db8a14 commit 7f02a6d
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 10 deletions.
95 changes: 89 additions & 6 deletions contracts/sumtree-orderbook/src/sumtree/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use cosmwasm_schema::cw_serde;
use cosmwasm_std::{ensure, Storage, Uint128};
use cw_storage_plus::Map;

use crate::{error::ContractResult, ContractError};
use crate::{error::ContractResult, sumtree::tree::TREE, ContractError};

pub const NODES: Map<&(u64, i64, u64), TreeNode> = Map::new("nodes");
pub const NODE_ID_COUNTER: Map<&(u64, i64), u64> = Map::new("node_id");
Expand Down Expand Up @@ -40,6 +40,8 @@ pub enum NodeType {
accumulator: Uint128,
// Range from min ETAS to max ETAS + value of max ETAS
range: (Uint128, Uint128),
// Amount of child nodes
weight: u64,
},
}

Expand All @@ -58,6 +60,7 @@ impl NodeType {
Self::Internal {
range: (range.0.into(), range.1.into()),
accumulator: accumulator.into(),
weight: 0,
}
}
}
Expand All @@ -67,6 +70,7 @@ impl Default for NodeType {
Self::Internal {
accumulator: Uint128::zero(),
range: (Uint128::MAX, Uint128::MIN),
weight: 0,
}
}
}
Expand Down Expand Up @@ -193,6 +197,23 @@ impl TreeNode {
self.set_value(self.get_value().checked_add(value)?)
}

pub fn get_weight(&self) -> u64 {
match self.node_type {
NodeType::Internal { weight, .. } => weight,
NodeType::Leaf { .. } => 0,
}
}

pub fn set_weight(&mut self, new_weight: u64) -> ContractResult<()> {
match &mut self.node_type {
NodeType::Internal { weight, .. } => {
*weight = new_weight;
Ok(())
}
NodeType::Leaf { .. } => Err(ContractError::InvalidNodeType),
}
}

/// Recalculates the range and accumulated value for a node and propagates it up the tree
///
/// Must be an internal node
Expand Down Expand Up @@ -283,6 +304,8 @@ impl TreeNode {
self.set_max_range(new_node.get_max_range())?;
}

self.set_weight(self.get_weight() + 1)?;

let maybe_left = self.get_left(storage)?;
let maybe_right = self.get_right(storage)?;

Expand Down Expand Up @@ -346,9 +369,12 @@ impl TreeNode {
let right_is_leaf = maybe_right
.clone()
.map_or(false, |right| !right.is_internal());
let is_higher_than_right_leaf = maybe_right.clone().map_or(false, |r| {
!r.is_internal() && new_node.get_min_range() >= r.get_max_range()
});

// Case 4
if left_is_leaf {
if left_is_leaf && !is_higher_than_right_leaf {
let mut left = maybe_left.unwrap();
let new_left = left.split(storage, new_node)?;
self.left = Some(new_left);
Expand All @@ -358,9 +384,6 @@ impl TreeNode {

// Case 5: Reordering
// TODO: Add edge case test for this
let is_higher_than_right_leaf = maybe_right.clone().map_or(false, |r| {
!r.is_internal() && new_node.get_min_range() >= r.get_max_range()
});
if is_higher_than_right_leaf && maybe_left.is_none() {
self.left = self.right;
self.right = Some(new_node.key);
Expand Down Expand Up @@ -455,6 +478,64 @@ impl TreeNode {
Ok(())
}

pub fn rotate_right(&mut self, storage: &mut dyn Storage) -> ContractResult<()> {
let maybe_left = self.get_left(storage)?;
ensure!(maybe_left.is_some(), ContractError::InvalidNodeType);

let mut left = maybe_left.unwrap();

left.parent = self.parent;
self.parent = Some(left.key);
self.left = left.right;

if let Some(mut right) = left.get_right(storage)? {
right.parent = Some(self.key);
right.save(storage)?;
}

left.right = Some(self.key);

left.save(storage)?;
self.save(storage)?;

self.sync_range_and_value(storage)?;

if left.parent.is_none() {
TREE.save(storage, &(left.book_id, left.tick_id), &left.key)?;
}

Ok(())
}

pub fn rotate_left(&mut self, storage: &mut dyn Storage) -> ContractResult<()> {
let maybe_right = self.get_right(storage)?;
ensure!(maybe_right.is_some(), ContractError::InvalidNodeType);

let mut right = maybe_right.unwrap();

right.parent = self.parent;
self.parent = Some(right.key);
self.right = right.left;

if let Some(mut left) = right.get_left(storage)? {
left.parent = Some(self.key);
left.save(storage)?;
}

right.left = Some(self.key);

right.save(storage)?;
self.save(storage)?;

self.sync_range_and_value(storage)?;

if right.parent.is_none() {
TREE.save(storage, &(right.book_id, right.tick_id), &right.key)?;
}

Ok(())
}

#[cfg(test)]
/// Depth first search traversal of tree
pub fn traverse(&self, storage: &dyn Storage) -> ContractResult<Vec<TreeNode>> {
Expand Down Expand Up @@ -515,7 +596,9 @@ impl Display for NodeType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
NodeType::Leaf { value, etas } => write!(f, "{etas} {value}"),
NodeType::Internal { accumulator, range } => {
NodeType::Internal {
accumulator, range, ..
} => {
write!(f, "{} {}-{}", accumulator, range.0, range.1)
}
}
Expand Down
164 changes: 163 additions & 1 deletion contracts/sumtree-orderbook/src/sumtree/test/test_node.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use cosmwasm_std::{testing::mock_dependencies, Uint128};

use crate::sumtree::node::{generate_node_id, NodeType, TreeNode, NODES};
use crate::sumtree::{
node::{generate_node_id, NodeType, TreeNode, NODES},
tree::get_root_node,
};

struct TestNodeInsertCase {
name: &'static str,
Expand Down Expand Up @@ -600,3 +603,162 @@ fn test_node_deletion_valid() {
}
}
}

enum BalanceDirection {
Left,
Right,
}

struct TreeRebalancingTestCase {
name: &'static str,
nodes: Vec<NodeType>,
// Depth first search ordering of node IDs (Could be improved?)
expected: Vec<u64>,
// Whether to print the tree
print: bool,
direction: BalanceDirection,
}

#[test]
fn test_tree_rebalancing() {
let book_id = 1;
let tick_id = 1;
let test_cases: Vec<TreeRebalancingTestCase> = vec![
TreeRebalancingTestCase {
name: "Left heavy tree",
nodes: vec![
NodeType::leaf(1u32, 1u32),
NodeType::leaf(9u32, 1u32),
NodeType::leaf(6u32, 1u32),
NodeType::leaf(3u32, 1u32),
NodeType::leaf(2u32, 1u32),
NodeType::leaf(4u32, 1u32),
],
expected: vec![],
print: true,
direction: BalanceDirection::Right,
},
TreeRebalancingTestCase {
name: "Right heavy tree",
nodes: vec![
NodeType::leaf(1u32, 1u32),
NodeType::leaf(5u32, 1u32),
NodeType::leaf(9u32, 1u32),
NodeType::leaf(6u32, 1u32),
// NodeType::leaf(3u32, 1u32),
// NodeType::leaf(2u32, 1u32),
// NodeType::leaf(4u32, 1u32),
],
expected: vec![],
print: true,
direction: BalanceDirection::Left,
},
];

for test in test_cases {
let mut deps = mock_dependencies();
let mut tree = TreeNode::new(
book_id,
tick_id,
generate_node_id(deps.as_mut().storage, book_id, tick_id).unwrap(),
NodeType::internal(Uint128::zero(), (u32::MAX, u32::MIN)),
);

for node in test.nodes {
let mut tree_node = TreeNode::new(
book_id,
tick_id,
generate_node_id(deps.as_mut().storage, book_id, tick_id).unwrap(),
node,
);
NODES
.save(
deps.as_mut().storage,
&(book_id, tick_id, tree_node.key),
&tree_node,
)
.unwrap();
tree.insert(deps.as_mut().storage, &mut tree_node).unwrap();
}

if test.print {
println!("Pre-Rotation Tree: {}", test.name);
println!("--------------------------");
let nodes = tree.traverse_bfs(deps.as_ref().storage).unwrap();
for (idx, row) in nodes.iter().enumerate() {
print_tree_row(row.clone(), idx == 0, (nodes.len() - idx - 1) as u32);
}
println!();
}

match test.direction {
BalanceDirection::Left => tree.rotate_left(deps.as_mut().storage).unwrap(),
BalanceDirection::Right => tree.rotate_right(deps.as_mut().storage).unwrap(),
}

let tree = get_root_node(deps.as_ref().storage, book_id, tick_id).unwrap();

if test.print {
println!("Post-Rotation Tree: {}", test.name);
println!("--------------------------");
let nodes = tree.traverse_bfs(deps.as_ref().storage).unwrap();
for (idx, row) in nodes.iter().enumerate() {
print_tree_row(row.clone(), idx == 0, (nodes.len() - idx - 1) as u32);
}
println!();
}

// let result = tree.traverse(deps.as_ref().storage).unwrap();

// assert_eq!(
// result,
// test.expected
// .iter()
// .map(|key| NODES
// .load(deps.as_ref().storage, &(book_id, tick_id, *key))
// .unwrap())
// .collect::<Vec<TreeNode>>()
// );

// Uncomment post rebalancing implementation
// let internals: Vec<&TreeNode> = result.iter().filter(|x| x.is_internal()).collect();
// for internal_node in internals {
// let left_node = internal_node.get_left(deps.as_ref().storage).unwrap();
// let right_node = internal_node.get_right(deps.as_ref().storage).unwrap();

// let accumulated_value = left_node
// .clone()
// .map(|x| x.get_value())
// .unwrap_or_default()
// .checked_add(
// right_node
// .clone()
// .map(|x| x.get_value())
// .unwrap_or_default(),
// )
// .unwrap();
// assert_eq!(internal_node.get_value(), accumulated_value);

// let min = left_node
// .clone()
// .map(|n| n.get_min_range())
// .unwrap_or(Uint128::MAX)
// .min(
// right_node
// .clone()
// .map(|n| n.get_min_range())
// .unwrap_or(Uint128::MAX),
// );
// let max = left_node
// .map(|n| n.get_max_range())
// .unwrap_or(Uint128::MIN)
// .max(
// right_node
// .map(|n| n.get_max_range())
// .unwrap_or(Uint128::MIN),
// );
// assert_eq!(internal_node.get_min_range(), min);
// assert_eq!(internal_node.get_max_range(), max);
// }
}
}
17 changes: 14 additions & 3 deletions contracts/sumtree-orderbook/src/sumtree/tree.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
use cosmwasm_std::Storage;
use cw_storage_plus::Map;

use super::node::TreeNode;
use crate::error::ContractResult;

use super::node::{TreeNode, NODES};

// TODO: REMOVE
#[allow(dead_code)]
pub const TREE: Map<&(u64, i64), TreeNode> = Map::new("tree");
pub const TREE: Map<&(u64, i64), u64> = Map::new("tree");

pub fn get_root_node(
storage: &dyn Storage,
book_id: u64,
tick_id: i64,
) -> ContractResult<TreeNode> {
let root_id = TREE.load(storage, &(book_id, tick_id))?;
Ok(NODES.load(storage, &(book_id, tick_id, root_id))?)
}

0 comments on commit 7f02a6d

Please sign in to comment.