Skip to content

Commit

Permalink
vtree: retrieve vtrees for literals
Browse files Browse the repository at this point in the history
  • Loading branch information
jsfpdn committed Oct 13, 2024
1 parent 088df0a commit 3a7de80
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion sdd-rs-lib/vtree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{cell::RefCell, fmt::Debug, rc::Rc};

use crate::{
dot_writer::{Dot, DotWriter, Edge, NodeType},
literal::VarLabel,
literal::{Literal, VarLabel},
manager::SddManager,
};

Expand Down Expand Up @@ -366,6 +366,25 @@ impl VTreeManager {
VTreeManager::set_inorder_indices(self.root.clone().unwrap(), 0);
}

pub(crate) fn get_variable_vtree(&self, variable: &VarLabel) -> Option<VTreeRef> {
fn find_vtree(vtree: &VTreeRef, variable: &VarLabel) -> Option<VTreeRef> {
match vtree.borrow().node.clone() {
Node::Internal(lc, rc) => find_vtree(&lc, variable)
.or(find_vtree(&rc, variable))
.or(None),
Node::Leaf(candidate_variable) => {
if *variable == candidate_variable {
Some(vtree.clone())
} else {
None
}
}
}
}

find_vtree(&self.root.clone().unwrap(), variable)
}

fn get_vtree(&self, index: u16) -> Option<VTreeRef> {
// TODO: This will get obsolete once VTrees are stored in a single hashmap.
let Some(mut current) = self.root.clone() else {
Expand Down Expand Up @@ -782,4 +801,41 @@ mod test {
assert_eq!(ord, VTreeOrder::Inequal);
assert_eq!(lca.borrow().idx, root_idx);
}

#[test]
fn literal_indices() {
let var_label_index = |vtree: Option<VTreeRef>| -> u16 { vtree.unwrap().borrow().idx };

let mut manager = VTreeManager::new();
manager.add_variable(VarLabel::new("A"));
manager.add_variable(VarLabel::new("B"));
manager.add_variable(VarLabel::new("C"));
manager.add_variable(VarLabel::new("D"));
// 1
// / \
// 0 3
// A / \
// 2 5
// B / \
// 4 6
// C D

assert_eq!(
var_label_index(manager.get_variable_vtree(&VarLabel::new("A"))),
0
);
assert_eq!(
var_label_index(manager.get_variable_vtree(&VarLabel::new("B"))),
2
);
assert_eq!(
var_label_index(manager.get_variable_vtree(&VarLabel::new("C"))),
4
);
assert_eq!(
var_label_index(manager.get_variable_vtree(&VarLabel::new("D"))),
6
);
assert_eq!(manager.get_variable_vtree(&VarLabel::new("E")), None);
}
}

0 comments on commit 3a7de80

Please sign in to comment.