Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transitive_closure_dag function #707

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/algorithm_functions/dag_algorithms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ DAG Algorithms
rustworkx.layers
rustworkx.transitive_reduction
rustworkx.topological_generations
rustworkx.transitive_closure_dag
1 change: 1 addition & 0 deletions docs/source/api/algorithm_functions/traversal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ Traversal
rustworkx.visit.BFSVisitor
rustworkx.visit.DijkstraVisitor
rustworkx.TopologicalSorter
rustworkx.descendants_at_distance
1 change: 1 addition & 0 deletions docs/source/api/pydigraph_api_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@ the functions from the explicitly typed based on the data type.
rustworkx.digraph_dijkstra_search
rustworkx.digraph_node_link_json
rustworkx.digraph_longest_simple_path
rustworkx.graph_descendants_at_distance
1 change: 1 addition & 0 deletions docs/source/api/pygraph_api_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@ typed API based on the data type.
rustworkx.graph_dijkstra_search
rustworkx.graph_node_link_json
rustworkx.graph_longest_simple_path
rustworkx.digraph_descendants_at_distance
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
features:
- |
Added a new function ``descendants_at_distance`` to the rustworkx-core
crate under the ``traversal`` module
- |
Added a new function ``build_transitive_closure_dag`` to the rustworkx-core
crate under the ``traversal`` module.
- |
Added a new function, :func:`~.transitive_closure_dag`, which provides
an optimize method for computing the transitive closure of an input
DAG.
- |
Added a new function :func:`~.descendants_at_distance` which provides
a method to find the nodes at a fixed distance from a source in
a graph object.
44 changes: 44 additions & 0 deletions rustworkx-core/src/traversal/descendants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use hashbrown::HashSet;
use petgraph::visit::{IntoNeighborsDirected, NodeCount, Visitable};

/// Returns all nodes at a fixed `distance` from `source` in `G`.
/// Args:
/// `graph`:
/// `source`:
/// `distance`:
pub fn descendants_at_distance<G>(graph: G, source: G::NodeId, distance: usize) -> Vec<G::NodeId>
where
G: Visitable + IntoNeighborsDirected + NodeCount,
G::NodeId: std::cmp::Eq + std::hash::Hash,
{
let mut current_layer: Vec<G::NodeId> = vec![source];
let mut layers: usize = 0;
let mut visited: HashSet<G::NodeId> = HashSet::with_capacity(graph.node_count());
visited.insert(source);
while !current_layer.is_empty() && layers < distance {
let mut next_layer: Vec<G::NodeId> = Vec::new();
for node in current_layer {
for child in graph.neighbors_directed(node, petgraph::Outgoing) {
if !visited.contains(&child) {
visited.insert(child);
next_layer.push(child);
}
}
}
current_layer = next_layer;
layers += 1;
}
current_layer
}
4 changes: 4 additions & 0 deletions rustworkx-core/src/traversal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
//! Module for graph traversal algorithms.

mod bfs_visit;
mod descendants;
mod dfs_edges;
mod dfs_visit;
mod dijkstra_visit;
mod transitive_closure;

use petgraph::prelude::*;
use petgraph::visit::GraphRef;
Expand All @@ -25,9 +27,11 @@ use petgraph::visit::VisitMap;
use petgraph::visit::Visitable;

pub use bfs_visit::{breadth_first_search, BfsEvent};
pub use descendants::descendants_at_distance;
pub use dfs_edges::dfs_edges;
pub use dfs_visit::{depth_first_search, DfsEvent};
pub use dijkstra_visit::{dijkstra_search, DijkstraEvent};
pub use transitive_closure::build_transitive_closure_dag;

/// Return if the expression is a break value, execute the provided statement
/// if it is a prune value.
Expand Down
83 changes: 83 additions & 0 deletions rustworkx-core/src/traversal/transitive_closure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use petgraph::algo::{toposort, Cycle};
use petgraph::data::Build;
use petgraph::visit::{
GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable,
};

use crate::traversal::descendants_at_distance;

/// Build a transitive closure out of a given DAG
///
/// This function will mutate a given DAG object (which is typically moved to
/// this function) into a transitive closure of the graph and then returned.
/// If you'd like to preserve the input graph pass a clone of the original graph.
/// The transitive closure of :math:`G = (V, E)` is a graph :math:`G+ = (V, E+)`
/// such that for all pairs of :math:`v, w` in :math:`V` there is an edge
/// :math:`(v, w) in :math:`E+` if and only if there is a non-null path from
/// :math:`v` to :math:`w` in :math:`G`. This funciton provides an optimized
/// path for computing the the transitive closure of a DAG, if the input graph
/// contains cycles it will error.
///
/// Arguments:
///
/// - `graph`: A mutable graph object representing the DAG
/// - `topological_order`: An optional `Vec` of node identifiers representing
/// the topological order to traverse the DAG with. If not specified the
/// `petgraph::algo::toposort` function will be called to generate this
/// - `default_edge_weight`: A callable function that takes no arguments and
/// returns the `EdgeWeight` type object to use for each edge added to
/// `graph
///
/// # Example
///
/// ```rust
/// use rustworkx_core::traversal::build_transitive_closure_dag;
///
/// let g = petgraph::graph::DiGraph::<i32, i32>::from_edges(&[(0, 1, 0), (1, 2, 0), (2, 3, 0)]);
///
/// let res = build_transitive_closure_dag(g, None, || -> i32 {0});
/// let out_graph = res.unwrap();
/// let out_edges: Vec<(usize, usize)> = out_graph
/// .edge_indices()
/// .map(|e| {
/// let endpoints = out_graph.edge_endpoints(e).unwrap();
/// (endpoints.0.index(), endpoints.1.index())
/// })
/// .collect();
/// assert_eq!(vec![(0, 1), (1, 2), (2, 3), (1, 3), (0, 3), (0, 2)], out_edges)
/// ```
pub fn build_transitive_closure_dag<'a, G, F>(
mut graph: G,
topological_order: Option<Vec<G::NodeId>>,
default_edge_weight: F,
) -> Result<G, Cycle<G::NodeId>>
where
G: NodeCount + Build + Clone,
for<'b> &'b G:
GraphBase<NodeId = G::NodeId> + Visitable + IntoNeighborsDirected + IntoNodeIdentifiers,
G::NodeId: std::cmp::Eq + std::hash::Hash,
F: Fn() -> G::EdgeWeight,
{
let node_order: Vec<G::NodeId> = match topological_order {
Some(topo_order) => topo_order,
None => toposort(&graph, None)?,
};
for node in node_order.into_iter().rev() {
for descendant in descendants_at_distance(&graph, node, 2) {
graph.add_edge(node, descendant, default_edge_weight());
}
}
Ok(graph)
}
25 changes: 25 additions & 0 deletions rustworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,30 @@ def longest_simple_path(graph):
"""


@_rustworkx_dispatch
def descendants_at_distance(graph, source, distance):
"""Returns all nodes at a fixed distance from ``source`` in ``graph``

:param graph: The graph to find the descendants in
:param int source: The node index to find the descendants from
:param int distance: The distance from ``source``

:returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``.
:rtype: NodeIndices

For example::

import rustworkx as rx

graph = rx.generators.path_graph(5)
res = rx.descendants_at_distance(graph, 2, 2)
print(res)

will return: ``[0, 4]``
"""
raise TypeError("Invalid Input Type %s for graph" % type(graph))


@_rustworkx_dispatch
def isolates(graph):
"""Return a list of isolates in a graph object
Expand Down Expand Up @@ -2006,3 +2030,4 @@ def all_shortest_paths(

"""
raise TypeError("Invalid Input Type %s for graph" % type(graph))
>>>>>>> origin/main
45 changes: 44 additions & 1 deletion src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ pub fn traversal_directions(reverse: bool) -> (petgraph::Direction, petgraph::Di
}
}

use rustworkx_core::traversal;

/// Find the longest path in a DAG
///
/// :param PyDiGraph graph: The graph to find the longest path on. The input
Expand Down Expand Up @@ -597,6 +599,48 @@ pub fn collect_bicolor_runs(
Ok(block_list)
}

/// Return the transitive closure of a directed acyclic graph
///
/// The transitive closure of :math:`G = (V, E)` is a graph :math:`G+ = (V, E+)`
/// such that for all pairs of :math:`v, w` in :math:`V` there is an edge
/// :math:`(v, w) in :math:`E+` if and only if there is a non-null path from
/// :math:`v` to :math:`w` in :math:`G`.
///
/// :param PyDiGraph graph: The input DAG to compute the transitive closure of
/// :param list topological_order: An optional topological order for ``graph``
/// which represents the order the graph will be traversed in computing
/// the transitive closure. If one is not provided (or it is explicitly
/// set to ``None``) a topological order will be computed by this function.
///
/// :returns: The transitive closure of ``graph``
/// :rtype: PyDiGraph
///
/// :raises DAGHasCycle: If the input ``graph`` is not acyclic
#[pyfunction]
#[pyo3(text_signature = "(graph, / topological_order=None)")]
pub fn transitive_closure_dag(
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
py: Python,
graph: &digraph::PyDiGraph,
topological_order: Option<Vec<usize>>,
) -> PyResult<digraph::PyDiGraph> {
let default_weight = || -> PyObject { py.None() };
match traversal::build_transitive_closure_dag(
graph.graph.clone(),
topological_order.map(|order| order.into_iter().map(NodeIndex::new).collect()),
default_weight,
) {
Ok(out_graph) => Ok(digraph::PyDiGraph {
graph: out_graph,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
multigraph: true,
attrs: py.None(),
}),
Err(_err) => Err(DAGHasCycle::new_err("Topological Sort encountered a cycle")),
}
}

/// Returns the transitive reduction of a directed acyclic graph
///
/// The transitive reduction of :math:`G = (V,E)` is a graph :math:`G\prime = (V,E\prime)`
Expand All @@ -612,7 +656,6 @@ pub fn collect_bicolor_runs(
/// :rtype: Tuple[PyGraph, dict]
///
/// :raises PyValueError: if ``graph`` is not a DAG

#[pyfunction]
#[pyo3(text_signature = "(graph, /)")]
pub fn transitive_reduction(
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,9 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(read_graphml))?;
m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(transitive_closure_dag))?;
m.add_wrapped(wrap_pyfunction!(graph_descendants_at_distance))?;
m.add_wrapped(wrap_pyfunction!(digraph_descendants_at_distance))?;
m.add_wrapped(wrap_pyfunction!(from_node_link_json_file))?;
m.add_wrapped(wrap_pyfunction!(parse_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(pagerank))?;
Expand Down
66 changes: 65 additions & 1 deletion src/traversal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use dijkstra_visit::{dijkstra_handler, PyDijkstraVisitor};
use rustworkx_core::traversal::{
ancestors as core_ancestors, bfs_predecessors as core_bfs_predecessors,
bfs_successors as core_bfs_successors, breadth_first_search, depth_first_search,
descendants as core_descendants, dfs_edges, dijkstra_search,
descendants as core_descendants, descendants_at_distance, dfs_edges, dijkstra_search,
};

use super::{digraph, graph, iterators, CostFn};
Expand Down Expand Up @@ -773,3 +773,67 @@ pub fn graph_dijkstra_search(

Ok(())
}

/// Returns all nodes at a fixed distance from ``source`` in ``graph``
///
/// :param PyGraph graph: The graph to find the descendants in
/// :param int source: The node index to find the descendants from
/// :param int distance: The distance from ``source``
///
/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``.
/// :rtype: NodeIndices
/// For example::
///
/// import rustworkx as rx
///
/// graph = rx.generators.path_graph(5)
/// res = rx.descendants_at_distance(graph, 2, 2)
/// print(res)
///
/// will return: ``[0, 4]``
#[pyfunction]
pub fn graph_descendants_at_distance(
graph: graph::PyGraph,
source: usize,
distance: usize,
) -> iterators::NodeIndices {
let source = NodeIndex::new(source);
iterators::NodeIndices {
nodes: descendants_at_distance(&graph.graph, source, distance)
.into_iter()
.map(|x| x.index())
.collect(),
}
}

/// Returns all nodes at a fixed distance from ``source`` in ``graph``
///
/// :param PyDiGraph graph: The graph to find the descendants in
/// :param int source: The node index to find the descendants from
/// :param int distance: The distance from ``source``
///
/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``.
/// :rtype: NodeIndices
/// For example::
///
/// import rustworkx as rx
///
/// graph = rx.generators.directed_path_graph(5)
/// res = rx.descendants_at_distance(graph, 2, 2)
/// print(res)
///
/// will return: ``[4]``
#[pyfunction]
pub fn digraph_descendants_at_distance(
graph: digraph::PyDiGraph,
source: usize,
distance: usize,
) -> iterators::NodeIndices {
let source = NodeIndex::new(source);
iterators::NodeIndices {
nodes: descendants_at_distance(&graph.graph, source, distance)
.into_iter()
.map(|x| x.index())
.collect(),
}
}
Loading
Loading