From 8001f6e6d693b0a21ea20d8d19c914aabd2530c1 Mon Sep 17 00:00:00 2001 From: gluonhiggs Date: Wed, 5 Jun 2024 15:09:12 +0700 Subject: [PATCH] modify the api layer, type annotations, the binder and the rust core --- ..._cycle_basis.rs => minimal_cycle_basis.rs} | 122 +++++++++++++----- rustworkx-core/src/connectivity/mod.rs | 4 +- rustworkx/__init__.py | 26 ++++ rustworkx/__init__.pyi | 2 +- rustworkx/rustworkx.pyi | 2 +- src/connectivity/minimum_cycle_basis.rs | 41 ++++++ src/connectivity/mod.rs | 24 +++- src/lib.rs | 1 + 8 files changed, 183 insertions(+), 39 deletions(-) rename rustworkx-core/src/connectivity/{minimum_cycle_basis.rs => minimal_cycle_basis.rs} (83%) create mode 100644 src/connectivity/minimum_cycle_basis.rs diff --git a/rustworkx-core/src/connectivity/minimum_cycle_basis.rs b/rustworkx-core/src/connectivity/minimal_cycle_basis.rs similarity index 83% rename from rustworkx-core/src/connectivity/minimum_cycle_basis.rs rename to rustworkx-core/src/connectivity/minimal_cycle_basis.rs index 8955ae8771..5ea97edb91 100644 --- a/rustworkx-core/src/connectivity/minimum_cycle_basis.rs +++ b/rustworkx-core/src/connectivity/minimal_cycle_basis.rs @@ -1,9 +1,9 @@ use crate::connectivity::conn_components::connected_components; use crate::dictmap::*; -use crate::shortest_path::dijkstra; +use crate::shortest_path::{astar, dijkstra}; use crate::Result; use hashbrown::{HashMap, HashSet}; -use petgraph::algo::{astar, min_spanning_tree, Measure}; +use petgraph::algo::{min_spanning_tree, Measure}; use petgraph::csr::{DefaultIx, IndexType}; use petgraph::data::{DataMap, Element}; use petgraph::graph::Graph; @@ -13,6 +13,7 @@ use petgraph::visit::{ IntoNeighborsDirected, IntoNodeIdentifiers, IntoNodeReferences, NodeIndexable, Visitable, }; use petgraph::Undirected; +use std::cmp::Ordering; use std::convert::Infallible; use std::hash::Hash; @@ -39,7 +40,7 @@ where G::NodeId: Eq + Hash, G::EdgeWeight: Clone, F: FnMut(G::EdgeRef) -> Result, - K: Clone + PartialOrd + Copy + Measure + Default + Ord, + K: Clone + PartialOrd + Copy + Measure + Default, { components .into_iter() @@ -77,7 +78,7 @@ where }) .collect() } -pub fn minimum_cycle_basis(graph: G, mut weight_fn: F) -> Result>, E> +pub fn minimal_cycle_basis(graph: G, mut weight_fn: F) -> Result>, E> where G: EdgeCount + IntoNodeIdentifiers @@ -88,10 +89,10 @@ where + IntoNeighborsDirected + Visitable + IntoEdges, - G::EdgeWeight: Clone + PartialOrd, + G::EdgeWeight: Clone, G::NodeId: Eq + Hash, F: FnMut(G::EdgeRef) -> Result, - K: Clone + PartialOrd + Copy + Measure + Default + Ord, + K: Clone + PartialOrd + Copy + Measure + Default, { let conn_components = connected_components(&graph); let mut min_cycle_basis = Vec::new(); @@ -136,7 +137,7 @@ where H::EdgeWeight: Clone + PartialOrd, H::NodeId: Eq + Hash, F: FnMut(H::EdgeRef) -> Result, - K: Clone + PartialOrd + Copy + Measure + Default + Ord, + K: Clone + PartialOrd + Copy + Measure + Default, { let mut sub_cb: Vec> = Vec::new(); let num_edges = subgraph.edge_count(); @@ -243,7 +244,7 @@ where H: IntoNodeReferences + IntoEdgeReferences + DataMap + NodeIndexable + EdgeIndexable, H::NodeId: Eq + Hash, F: FnMut(H::EdgeRef) -> Result, - K: Clone + PartialOrd + Copy + Measure + Default + Ord, + K: Clone + PartialOrd + Copy + Measure + Default, { let mut gi = Graph::<_, _, petgraph::Undirected>::default(); let mut subgraph_gi_map = HashMap::new(); @@ -290,23 +291,27 @@ where |edge| Ok(*edge.weight()), None, ); - // Find the shortest distance in the result and store it in the shortest_path_map let spl = result.unwrap()[&gi_lifted_nodeidx]; shortest_path_map.insert(subnodeid, spl); } - let min_start = shortest_path_map.iter().min_by_key(|x| x.1).unwrap().0; + let min_start = shortest_path_map + .iter() + .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal)) + .unwrap() + .0; let min_start_node = subgraph_gi_map[min_start].0; let min_start_lifted_node = subgraph_gi_map[min_start].1; - let result = astar( + let result: Result)>> = astar( &gi, - min_start_node, - |finish| finish == min_start_lifted_node, - |e| *e.weight(), - |_| K::default(), + min_start_node.clone(), + |finish| Ok(finish == min_start_lifted_node.clone()), + |e| Ok(*e.weight()), + |_| Ok(K::default()), ); + let mut min_path: Vec = Vec::new(); match result { - Some((_cost, path)) => { + Ok(Some((_cost, path))) => { for node in path { if let Some(&subgraph_nodeid) = gi_subgraph_map.get(&node) { let subgraph_node = NodeIndexable::to_index(&subgraph, subgraph_nodeid); @@ -314,7 +319,8 @@ where } } } - None => {} + Ok(None) => {} + Err(_) => {} } let edgelist = min_path .windows(2) @@ -344,9 +350,9 @@ where } #[cfg(test)] -mod test_minimum_cycle_basis { - use crate::connectivity::minimum_cycle_basis::minimum_cycle_basis; - use petgraph::graph::Graph; +mod test_minimal_cycle_basis { + use crate::connectivity::minimal_cycle_basis::minimal_cycle_basis; + use petgraph::graph::{Graph, NodeIndex}; use petgraph::Undirected; use std::convert::Infallible; @@ -356,7 +362,7 @@ mod test_minimum_cycle_basis { let weight_fn = |edge: petgraph::graph::EdgeReference| -> Result { Ok(*edge.weight()) }; - let output = minimum_cycle_basis(&graph, weight_fn).unwrap(); + let output = minimal_cycle_basis(&graph, weight_fn).unwrap(); assert_eq!(output.len(), 0); } @@ -372,8 +378,7 @@ mod test_minimum_cycle_basis { let weight_fn = |edge: petgraph::graph::EdgeReference| -> Result { Ok(*edge.weight()) }; - let cycles = minimum_cycle_basis(&graph, weight_fn); - println!("Cycles {:?}", cycles.as_ref().unwrap()); + let cycles = minimal_cycle_basis(&graph, weight_fn); assert_eq!(cycles.unwrap().len(), 1); } @@ -393,10 +398,60 @@ mod test_minimum_cycle_basis { let weight_fn = |edge: petgraph::graph::EdgeReference| -> Result { Ok(*edge.weight()) }; - let cycles = minimum_cycle_basis(&graph, weight_fn); + let cycles = minimal_cycle_basis(&graph, weight_fn); assert_eq!(cycles.unwrap().len(), 2); } + #[test] + fn test_non_trivial_graph() { + let mut g = Graph::<&str, i32, Undirected>::new_undirected(); + let a = g.add_node("A"); + let b = g.add_node("B"); + let c = g.add_node("C"); + let d = g.add_node("D"); + let e = g.add_node("E"); + let f = g.add_node("F"); + + g.add_edge(a, b, 7); + g.add_edge(c, a, 9); + g.add_edge(a, d, 11); + g.add_edge(b, c, 10); + g.add_edge(d, c, 2); + g.add_edge(d, e, 9); + g.add_edge(b, f, 15); + g.add_edge(c, f, 11); + g.add_edge(e, f, 6); + + let weight_fn = |edge: petgraph::graph::EdgeReference| -> Result { + Ok(*edge.weight()) + }; + let output = minimal_cycle_basis(&g, weight_fn); + let mut actual_output = output.unwrap(); + for cycle in &mut actual_output { + cycle.sort(); + } + actual_output.sort(); + + let expected_output: Vec> = vec![ + vec![ + NodeIndex::new(5), + NodeIndex::new(2), + NodeIndex::new(3), + NodeIndex::new(4), + ], + vec![NodeIndex::new(2), NodeIndex::new(5), NodeIndex::new(1)], + vec![NodeIndex::new(0), NodeIndex::new(2), NodeIndex::new(1)], + vec![NodeIndex::new(2), NodeIndex::new(3), NodeIndex::new(0)], + ]; + let mut sorted_expected_output = expected_output.clone(); + for cycle in &mut sorted_expected_output { + cycle.sort(); + } + sorted_expected_output.sort(); + + assert_eq!(actual_output, sorted_expected_output); + } + #[test] fn test_weighted_diamond_graph() { let mut weighted_diamond = Graph::<(), i32, Undirected>::new_undirected(); @@ -412,20 +467,19 @@ mod test_minimum_cycle_basis { let weight_fn = |edge: petgraph::graph::EdgeReference| -> Result { Ok(*edge.weight()) }; - let output = minimum_cycle_basis(&weighted_diamond, weight_fn); - let expected_output: Vec> = vec![vec![0, 1, 3], vec![0, 1, 2, 3]]; + let output = minimal_cycle_basis(&weighted_diamond, weight_fn); + let expected_output1: Vec> = vec![vec![0, 1, 3], vec![0, 1, 2, 3]]; + let expected_output2: Vec> = vec![vec![1, 2, 3], vec![0, 1, 2, 3]]; for cycle in output.unwrap().iter() { - println!("{:?}", cycle); let mut node_indices: Vec = Vec::new(); for node in cycle.iter() { node_indices.push(node.index()); } node_indices.sort(); - println!("Node indices {:?}", node_indices); - if expected_output.contains(&node_indices) { - println!("Found cycle {:?}", node_indices); - } - assert!(expected_output.contains(&node_indices)); + assert!( + expected_output1.contains(&node_indices) + || expected_output2.contains(&node_indices) + ); } } @@ -444,7 +498,7 @@ mod test_minimum_cycle_basis { let weight_fn = |_edge: petgraph::graph::EdgeReference<()>| -> Result { Ok(1) }; - let output = minimum_cycle_basis(&unweighted_diamond, weight_fn); + let output = minimal_cycle_basis(&unweighted_diamond, weight_fn); let expected_output: Vec> = vec![vec![0, 1, 3], vec![1, 2, 3]]; for cycle in output.unwrap().iter() { let mut node_indices: Vec = Vec::new(); @@ -476,7 +530,7 @@ mod test_minimum_cycle_basis { let weight_fn = |edge: petgraph::graph::EdgeReference| -> Result { Ok(*edge.weight()) }; - let output = minimum_cycle_basis(&complete_graph, weight_fn); + let output = minimal_cycle_basis(&complete_graph, weight_fn); for cycle in output.unwrap().iter() { assert_eq!(cycle.len(), 3); } diff --git a/rustworkx-core/src/connectivity/mod.rs b/rustworkx-core/src/connectivity/mod.rs index 3e881c415b..3df41b25a9 100644 --- a/rustworkx-core/src/connectivity/mod.rs +++ b/rustworkx-core/src/connectivity/mod.rs @@ -21,7 +21,7 @@ mod cycle_basis; mod find_cycle; mod isolates; mod min_cut; -mod minimum_cycle_basis; +mod minimal_cycle_basis; pub use all_simple_paths::{ all_simple_paths_multiple_targets, longest_simple_path_multiple_targets, @@ -37,4 +37,4 @@ pub use cycle_basis::cycle_basis; pub use find_cycle::find_cycle; pub use isolates::isolates; pub use min_cut::stoer_wagner_min_cut; -pub use minimum_cycle_basis::minimum_cycle_basis; +pub use minimal_cycle_basis::minimal_cycle_basis; diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 2943017fcc..f4da836ccf 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -557,6 +557,32 @@ def all_pairs_dijkstra_path_lengths(graph, edge_cost_fn): raise TypeError("Invalid Input Type %s for graph" % type(graph)) +@_rustworkx_dispatch +def minimum_cycle_basis(graph, edge_cost_fn): + """Find the minimum cycle basis of a graph. + + This function will find the minimum cycle basis of a graph based on the + following papers + References: + [1] Kavitha, Telikepalli, et al. "An O(m^2n) Algorithm for + Minimum Cycle Basis of Graphs." + http://link.springer.com/article/10.1007/s00453-007-9064-z + [2] de Pina, J. 1995. Applications of shortest path methods. + Ph.D. thesis, University of Amsterdam, Netherlands + + :param graph: The input graph to use. Can either be a + :class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph` + :param edge_cost_fn: A callable object that acts as a weight function for + an edge. It will accept a single positional argument, the edge's weight + object and will return a float which will be used to represent the + weight/cost of the edge + + :return: A list of cycles where each cycle is a list of node indices + + :rtype: list + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + @_rustworkx_dispatch def dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, goal=None): """Compute the lengths of the shortest paths for a graph object using diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index 0e157e6aba..140757a265 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -83,7 +83,7 @@ from .rustworkx import graph_longest_simple_path as graph_longest_simple_path from .rustworkx import digraph_core_number as digraph_core_number from .rustworkx import graph_core_number as graph_core_number from .rustworkx import stoer_wagner_min_cut as stoer_wagner_min_cut -from .rustworkx import minimum_cycle_basis as minimum_cycle_basis +from .rustworkx import graph_minimum_cycle_basis as graph_minimum_cycle_basis from .rustworkx import simple_cycles as simple_cycles from .rustworkx import digraph_isolates as digraph_isolates from .rustworkx import graph_isolates as graph_isolates diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 042d5276d3..2087df90f2 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -247,7 +247,7 @@ def stoer_wagner_min_cut( /, weight_fn: Callable[[_T], float] | None = ..., ) -> tuple[float, NodeIndices] | None: ... -def minimum_cycle_basis( +def graph_minimum_cycle_basis( graph: PyGraph[_S, _T], /, weight_fn: Callable[[_T], float] | None = ... diff --git a/src/connectivity/minimum_cycle_basis.rs b/src/connectivity/minimum_cycle_basis.rs new file mode 100644 index 0000000000..55b73739bf --- /dev/null +++ b/src/connectivity/minimum_cycle_basis.rs @@ -0,0 +1,41 @@ +use rustworkx_core::connectivity::minimal_cycle_basis; + +use pyo3::exceptions::PyIndexError; +use pyo3::prelude::*; +use pyo3::Python; + +use petgraph::graph::NodeIndex; +use petgraph::prelude::*; +use petgraph::visit::EdgeIndexable; +use petgraph::EdgeType; + +use crate::{CostFn, StablePyGraph}; + +pub fn minimum_cycle_basis_map( + py: Python, + graph: &StablePyGraph, + edge_cost_fn: PyObject, +) -> PyResult>> { + if graph.node_count() == 0 { + return Ok(vec![]); + } else if graph.edge_count() == 0 { + return Ok(vec![]); + } + let edge_cost_callable = CostFn::from(edge_cost_fn); + let mut edge_weights: Vec> = Vec::with_capacity(graph.edge_bound()); + for index in 0..=graph.edge_bound() { + let raw_weight = graph.edge_weight(EdgeIndex::new(index)); + match raw_weight { + Some(weight) => edge_weights.push(Some(edge_cost_callable.call(py, weight)?)), + None => edge_weights.push(None), + }; + } + let edge_cost = |e: EdgeIndex| -> PyResult { + match edge_weights[e.index()] { + Some(weight) => Ok(weight), + None => Err(PyIndexError::new_err("No edge found for index")), + } + }; + let cycle_basis = minimal_cycle_basis(graph, |e| edge_cost(e.id())).unwrap(); + Ok(cycle_basis) +} diff --git a/src/connectivity/mod.rs b/src/connectivity/mod.rs index 7b2075eea5..3ed475934c 100644 --- a/src/connectivity/mod.rs +++ b/src/connectivity/mod.rs @@ -14,6 +14,7 @@ mod all_pairs_all_simple_paths; mod johnson_simple_cycles; +mod minimum_cycle_basis; mod subgraphs; use super::{ @@ -22,10 +23,10 @@ use super::{ use hashbrown::{HashMap, HashSet}; +use petgraph::algo; use petgraph::stable_graph::NodeIndex; use petgraph::unionfind::UnionFind; use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeCount, NodeIndexable, Visitable}; -use petgraph::algo; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -918,6 +919,27 @@ pub fn stoer_wagner_min_cut( })) } +#[pyfunction] +#[pyo3(text_signature = "(graph, edge_cost_fn, /)")] +pub fn graph_minimum_cycle_basis( + py: Python, + graph: &graph::PyGraph, + edge_cost_fn: PyObject, +) -> PyResult>> { + let basis = minimum_cycle_basis::minimum_cycle_basis_map(py, &graph.graph, edge_cost_fn); + Ok(basis + .into_iter() + .map(|cycle| { + cycle + .into_iter() + .map(|node| NodeIndices { + nodes: node.iter().map(|nx| nx.index()).collect(), + }) + .collect() + }) + .collect()) +} + /// Return the articulation points of an undirected graph. /// /// An articulation point or cut vertex is any node whose removal (along with diff --git a/src/lib.rs b/src/lib.rs index cce7c91755..3d9eb63e01 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -570,6 +570,7 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { ))?; m.add_wrapped(wrap_pyfunction!(metric_closure))?; m.add_wrapped(wrap_pyfunction!(stoer_wagner_min_cut))?; + m.add_wrapped(wrap_pyfunction!(graph_minimum_cycle_basis))?; m.add_wrapped(wrap_pyfunction!(steiner_tree::steiner_tree))?; m.add_wrapped(wrap_pyfunction!(digraph_dfs_search))?; m.add_wrapped(wrap_pyfunction!(graph_dfs_search))?;