Skip to content

Commit

Permalink
modified the binder and API layer
Browse files Browse the repository at this point in the history
  • Loading branch information
gluonhiggs committed Jun 12, 2024
1 parent 349e56a commit d5304e6
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 24 deletions.
5 changes: 2 additions & 3 deletions rustworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,15 +570,14 @@ def minimum_cycle_basis(graph, edge_cost_fn):
[2] de Pina, J. 1995. Applications of shortest path methods.
Ph.D. thesis, University of Amsterdam, Netherlands
:param graph: The input graph to use. Can either be a
:param graph: The input graph to use. Can be either a
:class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`
:param edge_cost_fn: A callable object that acts as a weight function for
an edge. It will accept a single positional argument, the edge's weight
object and will return a float which will be used to represent the
weight/cost of the edge
:return: A list of cycles where each cycle is a list of node indices
:returns: A list of cycles where each cycle is a list of node indices
:rtype: list
"""
raise TypeError("Invalid Input Type %s for graph" % type(graph))
Expand Down
1 change: 1 addition & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ from .rustworkx import digraph_core_number as digraph_core_number
from .rustworkx import graph_core_number as graph_core_number
from .rustworkx import stoer_wagner_min_cut as stoer_wagner_min_cut
from .rustworkx import graph_minimum_cycle_basis as graph_minimum_cycle_basis
from .rustworkx import digraph_minimum_cycle_basis as digraph_minimum_cycle_basis
from .rustworkx import simple_cycles as simple_cycles
from .rustworkx import digraph_isolates as digraph_isolates
from .rustworkx import graph_isolates as graph_isolates
Expand Down
12 changes: 10 additions & 2 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ from typing import (
Callable,
Iterable,
Iterator,
Union,
final,
Sequence,
Any,
Expand Down Expand Up @@ -248,8 +249,15 @@ def stoer_wagner_min_cut(
weight_fn: Callable[[_T], float] | None = ...,
) -> tuple[float, NodeIndices] | None: ...
def graph_minimum_cycle_basis(
graph: PyGraph[_S, _T], /, weight_fn: Callable[[_T], float] | None = ...
) -> list[list[NodeIndices]] | None: ...
graph: PyGraph[_S, _T],
edge_cost: Callable[[_T], float],
/,
) -> list[list[NodeIndices]]: ...
def digraph_minimum_cycle_basis(
graph: PyDiGraph[_S, _T],
edge_cost: Callable[[_T], float],
/,
) -> list[list[NodeIndices]]: ...
def simple_cycles(graph: PyDiGraph, /) -> Iterator[NodeIndices]: ...
def graph_isolates(graph: PyGraph) -> NodeIndices: ...
def digraph_isolates(graph: PyDiGraph) -> NodeIndices: ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@ use pyo3::exceptions::PyIndexError;
use pyo3::prelude::*;
use pyo3::Python;

use petgraph::graph::NodeIndex;
use crate::iterators::NodeIndices;
use crate::{CostFn, StablePyGraph};
use petgraph::prelude::*;
use petgraph::visit::EdgeIndexable;
use petgraph::EdgeType;

use crate::{CostFn, StablePyGraph};

pub fn minimum_cycle_basis_map<Ty: EdgeType + Sync>(
pub fn minimum_cycle_basis<Ty: EdgeType + Sync>(
py: Python,
graph: &StablePyGraph<Ty>,
edge_cost_fn: PyObject,
) -> PyResult<Vec<Vec<NodeIndex>>> {
) -> PyResult<Vec<Vec<NodeIndices>>> {
if graph.node_count() == 0 || graph.edge_count() == 0 {
return Ok(vec![]);
}
Expand All @@ -35,5 +34,17 @@ pub fn minimum_cycle_basis_map<Ty: EdgeType + Sync>(
}
};
let cycle_basis = minimal_cycle_basis(graph, |e| edge_cost(e.id())).unwrap();
Ok(cycle_basis)
// Convert the cycle basis to a list of lists of node indices
let result: Vec<Vec<NodeIndices>> = cycle_basis
.into_iter()
.map(|cycle| {
cycle
.into_iter()
.map(|node| NodeIndices {
nodes: vec![node.index()],
})
.collect()
})
.collect();
Ok(result)
}
55 changes: 42 additions & 13 deletions src/connectivity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

mod all_pairs_all_simple_paths;
mod johnson_simple_cycles;
mod minimum_cycle_basis;
mod min_cycle_basis;
mod subgraphs;

use super::{
Expand Down Expand Up @@ -919,25 +919,54 @@ pub fn stoer_wagner_min_cut(
}))
}

/// Find a minimum cycle basis of an undirected graph.
/// All weights must be nonnegative. If the input graph does not have
/// any nodes or edges, this function returns ``None``.
/// If the input graph does not any weight, this function will find the
/// minimum cycle basis with the weight of 1.0 for all edges.
///
/// :param PyGraph: The undirected graph to be used
/// :param Callable edge_cost_fn: An optional callable object (function, lambda, etc) which
/// will be passed the edge object and expected to return a ``float``.
/// Edges with ``NaN`` weights will be considered to have 1.0 weight.
/// If ``edge_cost_fn`` is not specified a default value of ``1.0`` will be used for all edges.
///
/// :returns: A list of cycles, where each cycle is a list of node indices
/// :rtype: list
#[pyfunction]
#[pyo3(text_signature = "(graph, edge_cost_fn, /)")]
pub fn graph_minimum_cycle_basis(
py: Python,
graph: &graph::PyGraph,
edge_cost_fn: PyObject,
) -> PyResult<Vec<Vec<NodeIndices>>> {
let basis = minimum_cycle_basis::minimum_cycle_basis_map(py, &graph.graph, edge_cost_fn);
Ok(basis
.into_iter()
.map(|cycle| {
cycle
.into_iter()
.map(|node| NodeIndices {
nodes: node.iter().map(|nx| nx.index()).collect(),
})
.collect()
})
.collect())
min_cycle_basis::minimum_cycle_basis(py, &graph.graph, edge_cost_fn)
}

/// Find a minimum cycle basis of a directed graph (which is not of interest in the context
/// of minimum cycle basis). This function will return the minimum cycle basis of the
/// underlying undirected graph of the input directed graph.
/// All weights must be nonnegative. If the input graph does not have
/// any nodes or edges, this function returns ``None``.
/// If the input graph does not any weight, this function will find the
/// minimum cycle basis with the weight of 1.0 for all edges.
///
/// :param PyDiGraph: The directed graph to be used
/// :param Callable edge_cost_fn: An optional callable object (function, lambda, etc) which
/// will be passed the edge object and expected to return a ``float``.
/// Edges with ``NaN`` weights will be considered to have 1.0 weight.
/// If ``edge_cost_fn`` is not specified a default value of ``1.0`` will be used for all edges.
///
/// :returns: A list of cycles, where each cycle is a list of node indices
/// :rtype: list
#[pyfunction]
#[pyo3(text_signature = "(graph, edge_cost_fn, /)")]
pub fn digraph_minimum_cycle_basis(
py: Python,
graph: &digraph::PyDiGraph,
edge_cost_fn: PyObject,
) -> PyResult<Vec<Vec<NodeIndices>>> {
min_cycle_basis::minimum_cycle_basis(py, &graph.graph, edge_cost_fn)
}

/// Return the articulation points of an undirected graph.
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(metric_closure))?;
m.add_wrapped(wrap_pyfunction!(stoer_wagner_min_cut))?;
m.add_wrapped(wrap_pyfunction!(graph_minimum_cycle_basis))?;
m.add_wrapped(wrap_pyfunction!(digraph_minimum_cycle_basis))?;
m.add_wrapped(wrap_pyfunction!(steiner_tree::steiner_tree))?;
m.add_wrapped(wrap_pyfunction!(digraph_dfs_search))?;
m.add_wrapped(wrap_pyfunction!(graph_dfs_search))?;
Expand Down

0 comments on commit d5304e6

Please sign in to comment.