diff --git a/src/tree.rs b/src/tree.rs index 17460904b..c43d8db6c 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -17,15 +17,15 @@ use super::{graph, weight_callable}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::{PyFloat, PyList}; use pyo3::Python; -use petgraph::algo::{connected_components, is_cyclic_undirected}; use petgraph::prelude::*; use petgraph::stable_graph::EdgeReference; use petgraph::unionfind::UnionFind; use petgraph::visit::{IntoEdgeReferences, NodeIndexable}; +use numpy::PyReadonlyArray1; + use rayon::prelude::*; use crate::iterators::WeightedEdgeList; @@ -165,84 +165,73 @@ pub fn minimum_spanning_tree( pub fn balanced_cut_edge( py: Python, spanning_tree: &graph::PyGraph, - py_pops: &PyList, - py_pop_target: &PyFloat, - py_epsilon: &PyFloat, + pops: Vec, + pop_target: f64, + epsilon: f64, ) -> PyResult)>> { - let epsilon = py_epsilon.value(); - let pop_target = py_pop_target.value(); - - let mut pops: Vec = vec![]; // not sure if the conversions are needed - for i in 0..py_pops.len() { - pops.push(py_pops.get_item(i).unwrap().extract::().unwrap()); - } - + let mut pops = pops.clone(); let spanning_tree_graph = &spanning_tree.graph; - - let balanced_nodes = py.allow_threads(move || { - let mut same_partition_tracker: Vec> = - vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition - let mut node_queue: VecDeque = VecDeque::::new(); - for leaf_node in spanning_tree_graph.node_indices() { - // todo: filter expr - if spanning_tree_graph.neighbors(leaf_node).count() == 1 { - node_queue.push_back(leaf_node); - } - same_partition_tracker[leaf_node.index()].push(leaf_node.index()); + let mut same_partition_tracker: Vec> = + vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition + let mut node_queue: VecDeque = VecDeque::::new(); + for leaf_node in spanning_tree_graph.node_indices() { + // todo: filter expr + if spanning_tree_graph.neighbors(leaf_node).count() == 1 { + node_queue.push_back(leaf_node); } + same_partition_tracker[leaf_node.index()].push(leaf_node.index()); + } - // eprintln!("leaf nodes: {}", node_queue.len()); - - // this process can be multithreaded, if the locking overhead isn't too high - // (note: locking may not even be needed given the invariants this is assumed to maintain) - let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; - let mut seen_nodes: Vec = vec![false; spanning_tree_graph.node_count()]; // todo: perf test this - while node_queue.len() > 0 { - let node = node_queue.pop_front().unwrap(); - if seen_nodes[node.index()] { - // should not need this - // eprintln!("Invalid state! Double vision . . ."); - continue; - } - let pop = pops[node.index()]; + // eprintln!("leaf nodes: {}", node_queue.len()); - // todo: factor out expensive clones - // Mark as seen; push to queue if only one unseen neighbor - let unseen_neighbors: Vec = spanning_tree - .graph - .neighbors(node) - .filter(|node| !seen_nodes[node.index()]) - .collect(); - // eprintln!("unseen_neighbors: {}", unseen_neighbors.len()); - if unseen_neighbors.len() == 1 { - // this will be false if root - let neighbor = unseen_neighbors[0]; - pops[neighbor.index()] += pop.clone(); - let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); - same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); - // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); + // this process can be multithreaded, if the locking overhead isn't too high + // (note: locking may not even be needed given the invariants this is assumed to maintain) + let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; + let mut seen_nodes: Vec = vec![false; spanning_tree_graph.node_count()]; // todo: perf test this + while node_queue.len() > 0 { + let node = node_queue.pop_front().unwrap(); + if seen_nodes[node.index()] { + // should not need this + // eprintln!("Invalid state! Double vision . . ."); + continue; + } + let pop = pops[node.index()]; - if !node_queue.contains(&neighbor) { - node_queue.push_back(neighbor); - } - } else if unseen_neighbors.len() == 0 { - break; - } else { - continue; - } - // pops[node.index()] = 0.0; // not needed? + // todo: factor out expensive clones + // Mark as seen; push to queue if only one unseen neighbor + let unseen_neighbors: Vec = spanning_tree + .graph + .neighbors(node) + .filter(|node| !seen_nodes[node.index()]) + .collect(); + // eprintln!("unseen_neighbors: {}", unseen_neighbors.len()); + if unseen_neighbors.len() == 1 { + // this will be false if root + let neighbor = unseen_neighbors[0]; + pops[neighbor.index()] += pop.clone(); + let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); + same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); + // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); - // Check if balanced - if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { - // slightly different - // eprintln!("balanced node found: {}", node.index()); - balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); + if !node_queue.contains(&neighbor) { + node_queue.push_back(neighbor); } + } else if unseen_neighbors.len() == 0 { + break; + } else { + continue; + } + // pops[node.index()] = 0.0; // not needed? - seen_nodes[node.index()] = true; + // Check if balanced + if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { + // slightly different + // eprintln!("balanced node found: {}", node.index()); + balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); } - balanced_nodes - }); + + seen_nodes[node.index()] = true; + } Ok(balanced_nodes) } @@ -274,9 +263,9 @@ pub fn bipartition_tree( py: Python, graph: &graph::PyGraph, weight_fn: PyObject, - py_pops: Vec, - py_pop_target: f64, - py_epsilon: f64, + pops: Vec, + pop_target: f64, + epsilon: f64, ) -> PyResult)>> { let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; @@ -289,7 +278,7 @@ pub fn bipartition_tree( let mst = minimum_spanning_tree(py, graph, Some(weight_fn.clone()), 1.0).unwrap(); // assert_eq!(is_cyclic_undirected(&mst.graph), false); // assert_eq!(connected_components(&mst.graph), 1); - balanced_nodes = balanced_cut_edge(py, &mst, py_pops, py_pop_target, py_epsilon).unwrap(); + balanced_nodes = balanced_cut_edge(py, &mst, pops.clone(), pop_target, epsilon).unwrap(); } Ok(balanced_nodes)