Skip to content

Commit

Permalink
Take advantage of built-in retworkx macros
Browse files Browse the repository at this point in the history
  • Loading branch information
InnovativeInventor committed May 16, 2022
1 parent d843662 commit cd7cecc
Showing 1 changed file with 63 additions and 74 deletions.
137 changes: 63 additions & 74 deletions src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ use super::{graph, weight_callable};

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyFloat, PyList};
use pyo3::Python;

use petgraph::algo::{connected_components, is_cyclic_undirected};
use petgraph::prelude::*;
use petgraph::stable_graph::EdgeReference;
use petgraph::unionfind::UnionFind;
use petgraph::visit::{IntoEdgeReferences, NodeIndexable};

use numpy::PyReadonlyArray1;

use rayon::prelude::*;

use crate::iterators::WeightedEdgeList;
Expand Down Expand Up @@ -165,84 +165,73 @@ pub fn minimum_spanning_tree(
pub fn balanced_cut_edge(
py: Python,
spanning_tree: &graph::PyGraph,
py_pops: &PyList,
py_pop_target: &PyFloat,
py_epsilon: &PyFloat,
pops: Vec<f64>,
pop_target: f64,
epsilon: f64,
) -> PyResult<Vec<(usize, Vec<usize>)>> {
let epsilon = py_epsilon.value();
let pop_target = py_pop_target.value();

let mut pops: Vec<f64> = vec![]; // not sure if the conversions are needed
for i in 0..py_pops.len() {
pops.push(py_pops.get_item(i).unwrap().extract::<f64>().unwrap());
}

let mut pops = pops.clone();
let spanning_tree_graph = &spanning_tree.graph;

let balanced_nodes = py.allow_threads(move || {
let mut same_partition_tracker: Vec<Vec<usize>> =
vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition
let mut node_queue: VecDeque<NodeIndex> = VecDeque::<NodeIndex>::new();
for leaf_node in spanning_tree_graph.node_indices() {
// todo: filter expr
if spanning_tree_graph.neighbors(leaf_node).count() == 1 {
node_queue.push_back(leaf_node);
}
same_partition_tracker[leaf_node.index()].push(leaf_node.index());
let mut same_partition_tracker: Vec<Vec<usize>> =
vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition
let mut node_queue: VecDeque<NodeIndex> = VecDeque::<NodeIndex>::new();
for leaf_node in spanning_tree_graph.node_indices() {
// todo: filter expr
if spanning_tree_graph.neighbors(leaf_node).count() == 1 {
node_queue.push_back(leaf_node);
}
same_partition_tracker[leaf_node.index()].push(leaf_node.index());
}

// eprintln!("leaf nodes: {}", node_queue.len());

// this process can be multithreaded, if the locking overhead isn't too high
// (note: locking may not even be needed given the invariants this is assumed to maintain)
let mut balanced_nodes: Vec<(usize, Vec<usize>)> = vec![];
let mut seen_nodes: Vec<bool> = vec![false; spanning_tree_graph.node_count()]; // todo: perf test this
while node_queue.len() > 0 {
let node = node_queue.pop_front().unwrap();
if seen_nodes[node.index()] {
// should not need this
// eprintln!("Invalid state! Double vision . . .");
continue;
}
let pop = pops[node.index()];
// eprintln!("leaf nodes: {}", node_queue.len());

// todo: factor out expensive clones
// Mark as seen; push to queue if only one unseen neighbor
let unseen_neighbors: Vec<NodeIndex> = spanning_tree
.graph
.neighbors(node)
.filter(|node| !seen_nodes[node.index()])
.collect();
// eprintln!("unseen_neighbors: {}", unseen_neighbors.len());
if unseen_neighbors.len() == 1 {
// this will be false if root
let neighbor = unseen_neighbors[0];
pops[neighbor.index()] += pop.clone();
let mut current_partition_tracker = same_partition_tracker[node.index()].clone();
same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker);
// eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index());
// this process can be multithreaded, if the locking overhead isn't too high
// (note: locking may not even be needed given the invariants this is assumed to maintain)
let mut balanced_nodes: Vec<(usize, Vec<usize>)> = vec![];
let mut seen_nodes: Vec<bool> = vec![false; spanning_tree_graph.node_count()]; // todo: perf test this
while node_queue.len() > 0 {
let node = node_queue.pop_front().unwrap();
if seen_nodes[node.index()] {
// should not need this
// eprintln!("Invalid state! Double vision . . .");
continue;
}
let pop = pops[node.index()];

if !node_queue.contains(&neighbor) {
node_queue.push_back(neighbor);
}
} else if unseen_neighbors.len() == 0 {
break;
} else {
continue;
}
// pops[node.index()] = 0.0; // not needed?
// todo: factor out expensive clones
// Mark as seen; push to queue if only one unseen neighbor
let unseen_neighbors: Vec<NodeIndex> = spanning_tree
.graph
.neighbors(node)
.filter(|node| !seen_nodes[node.index()])
.collect();
// eprintln!("unseen_neighbors: {}", unseen_neighbors.len());
if unseen_neighbors.len() == 1 {
// this will be false if root
let neighbor = unseen_neighbors[0];
pops[neighbor.index()] += pop.clone();
let mut current_partition_tracker = same_partition_tracker[node.index()].clone();
same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker);
// eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index());

// Check if balanced
if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) {
// slightly different
// eprintln!("balanced node found: {}", node.index());
balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone()));
if !node_queue.contains(&neighbor) {
node_queue.push_back(neighbor);
}
} else if unseen_neighbors.len() == 0 {
break;
} else {
continue;
}
// pops[node.index()] = 0.0; // not needed?

seen_nodes[node.index()] = true;
// Check if balanced
if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) {
// slightly different
// eprintln!("balanced node found: {}", node.index());
balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone()));
}
balanced_nodes
});

seen_nodes[node.index()] = true;
}

Ok(balanced_nodes)
}
Expand Down Expand Up @@ -274,9 +263,9 @@ pub fn bipartition_tree(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
py_pops: Vec<f64>,
py_pop_target: f64,
py_epsilon: f64,
pops: Vec<f64>,
pop_target: f64,
epsilon: f64,
) -> PyResult<Vec<(usize, Vec<usize>)>> {
let mut balanced_nodes: Vec<(usize, Vec<usize>)> = vec![];

Expand All @@ -289,7 +278,7 @@ pub fn bipartition_tree(
let mst = minimum_spanning_tree(py, graph, Some(weight_fn.clone()), 1.0).unwrap();
// assert_eq!(is_cyclic_undirected(&mst.graph), false);
// assert_eq!(connected_components(&mst.graph), 1);
balanced_nodes = balanced_cut_edge(py, &mst, py_pops, py_pop_target, py_epsilon).unwrap();
balanced_nodes = balanced_cut_edge(py, &mst, pops.clone(), pop_target, epsilon).unwrap();
}

Ok(balanced_nodes)
Expand Down

0 comments on commit cd7cecc

Please sign in to comment.