Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transitive_closure_dag function #707

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ Traversal
rustworkx.visit.BFSVisitor
rustworkx.visit.DijkstraVisitor
rustworkx.TopologicalSorter
rustworkx.descendants_at_distance

.. _dag-algorithms:

Expand All @@ -94,6 +95,7 @@ DAG Algorithms
rustworkx.dag_weighted_longest_path_length
rustworkx.is_directed_acyclic_graph
rustworkx.layers
rustworkx.transitive_closure_dag

.. _tree:

Expand Down Expand Up @@ -325,6 +327,7 @@ the functions from the explicitly typed based on the data type.
rustworkx.digraph_bfs_search
rustworkx.digraph_dijkstra_search
rustworkx.digraph_node_link_json
rustworkx.digraph_descendants_at_distance

.. _api-functions-pygraph:

Expand Down Expand Up @@ -379,6 +382,7 @@ typed API based on the data type.
rustworkx.graph_bfs_search
rustworkx.graph_dijkstra_search
rustworkx.graph_node_link_json
rustworkx.graph_descendants_at_distance

Exceptions
==========
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
features:
- |
Added a new function ``descendants_at_distance`` to the rustworkx-core
crate under the ``traversal`` module
- |
Added a new function, :func:`~.transitive_closure_dag`, which provides
an optimize method for computing the transitive closure of an input
DAG.
- |
Added a new function :func:`~.descendants_at_distance` which provides
a method to find the nodes at a fixed distance from a source in
a graph object.
44 changes: 44 additions & 0 deletions rustworkx-core/src/traversal/descendants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use hashbrown::HashSet;
use petgraph::visit::{IntoNeighborsDirected, NodeCount, Visitable};

/// Returns all nodes at a fixed `distance` from `source` in `G`.
/// Args:
/// `graph`:
/// `source`:
/// `distance`:
pub fn descendants_at_distance<G>(graph: G, source: G::NodeId, distance: usize) -> Vec<G::NodeId>
where
G: Visitable + IntoNeighborsDirected + NodeCount,
G::NodeId: std::cmp::Eq + std::hash::Hash,
{
let mut current_layer: Vec<G::NodeId> = vec![source];
let mut layers: usize = 0;
let mut visited: HashSet<G::NodeId> = HashSet::with_capacity(graph.node_count());
visited.insert(source);
while !current_layer.is_empty() && layers < distance {
let mut next_layer: Vec<G::NodeId> = Vec::new();
for node in current_layer {
for child in graph.neighbors_directed(node, petgraph::Outgoing) {
if !visited.contains(&child) {
visited.insert(child);
next_layer.push(child);
}
}
}
current_layer = next_layer;
layers += 1;
}
current_layer
}
2 changes: 2 additions & 0 deletions rustworkx-core/src/traversal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
//! Module for graph traversal algorithms.

mod bfs_visit;
mod descendants;
mod dfs_edges;
mod dfs_visit;
mod dijkstra_visit;

pub use bfs_visit::{breadth_first_search, BfsEvent};
pub use descendants::descendants_at_distance;
pub use dfs_edges::dfs_edges;
pub use dfs_visit::{depth_first_search, DfsEvent};
pub use dijkstra_visit::{dijkstra_search, DijkstraEvent};
Expand Down
34 changes: 34 additions & 0 deletions rustworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2382,3 +2382,37 @@ def _graph_node_link_json(graph, path=None, graph_attrs=None, node_attrs=None, e
return graph_node_link_json(
graph, path=path, graph_attrs=graph_attrs, node_attrs=node_attrs, edge_attrs=edge_attrs
)


@functools.singledispatch
def descendants_at_distance(graph, source, distance):
"""Returns all nodes at a fixed distance from ``source`` in ``graph``

:param graph: The graph to find the descendants in
:param int source: The node index to find the descendants from
:param int distance: The distance from ``source``

:returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``.
:rtype: NodeIndices

For example::

import rustworkx as rx

graph = rx.generators.path_graph(5)
res = rx.descendants_at_distance(graph, 2, 2)
print(res)

will return: ``[0, 4]``
"""
raise TypeError("Invalid Input Type %s for graph" % type(graph))


@descendants_at_distance.register(PyDiGraph)
def _digraph_descendants_at_distance(graph, source, distance):
return digraph_descendants_at_distance(graph, source, distance)


@descendants_at_distance.register(PyGraph)
def _graph_descendants_at_distance(graph, source, distance):
return graph_descendants_at_distance(graph, source, distance)
49 changes: 49 additions & 0 deletions src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use petgraph::graph::NodeIndex;
use petgraph::prelude::*;
use petgraph::visit::NodeCount;

use rustworkx_core::traversal::descendants_at_distance;

/// Find the longest path in a DAG
///
/// :param PyDiGraph graph: The graph to find the longest path on. The input
Expand Down Expand Up @@ -634,3 +636,50 @@ pub fn collect_bicolor_runs(

Ok(block_list)
}

/// Return the transitive closure of a directed acyclic graph
///
/// The transitive closure of :math:`G = (V, E)` is a graph :math:`G+ = (V, E+)`
/// such that for all pairs of :math:`v, w` in :math:`V` there is an edge
/// :math:`(v, w) in :math:`E+` if and only if there is a non-null path from
/// :math:`v` to :math:`w` in :math:`G`.
///
/// :param PyDiGraph graph: The input DAG to compute the transitive closure of
/// :param list topological_order: An optional topological order for ``graph``
/// which represents the order the graph will be traversed in computing
/// the transitive closure. If one is not provided (or it is explicitly
/// set to ``None``) a topological order will be computed by this function.
///
/// :returns: The transitive closure of ``graph``
/// :rtype: PyDiGraph
///
/// :raises DAGHasCycle: If the input ``graph`` is not acyclic
#[pyfunction]
#[pyo3(text_signature = "(graph, / topological_order=None)")]
pub fn transitive_closure_dag(
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
py: Python,
graph: &digraph::PyDiGraph,
topological_order: Option<Vec<usize>>,
) -> PyResult<digraph::PyDiGraph> {
let node_order: Vec<NodeIndex> = match topological_order {
Some(topo_order) => topo_order.into_iter().map(NodeIndex::new).collect(),
None => match algo::toposort(&graph.graph, None) {
Ok(nodes) => nodes,
Err(_err) => return Err(DAGHasCycle::new_err("Topological Sort encountered a cycle")),
},
};
let mut out_graph = graph.graph.clone();
for node in node_order.into_iter().rev() {
for descendant in descendants_at_distance(&out_graph, node, 2) {
out_graph.add_edge(node, descendant, py.None());
}
}
Ok(digraph::PyDiGraph {
graph: out_graph,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
multigraph: true,
attrs: py.None(),
})
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,9 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(read_graphml))?;
m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(transitive_closure_dag))?;
m.add_wrapped(wrap_pyfunction!(graph_descendants_at_distance))?;
m.add_wrapped(wrap_pyfunction!(digraph_descendants_at_distance))?;
m.add_class::<digraph::PyDiGraph>()?;
m.add_class::<graph::PyGraph>()?;
m.add_class::<toposort::TopologicalSorter>()?;
Expand Down
66 changes: 65 additions & 1 deletion src/traversal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use dfs_visit::{dfs_handler, PyDfsVisitor};
use dijkstra_visit::{dijkstra_handler, PyDijkstraVisitor};

use rustworkx_core::traversal::{
breadth_first_search, depth_first_search, dfs_edges, dijkstra_search,
breadth_first_search, depth_first_search, descendants_at_distance, dfs_edges, dijkstra_search,
};

use super::{digraph, graph, iterators, CostFn};
Expand Down Expand Up @@ -707,3 +707,67 @@ pub fn graph_dijkstra_search(

Ok(())
}

/// Returns all nodes at a fixed distance from ``source`` in ``graph``
///
/// :param PyGraph graph: The graph to find the descendants in
/// :param int source: The node index to find the descendants from
/// :param int distance: The distance from ``source``
///
/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``.
/// :rtype: NodeIndices
/// For example::
///
/// import rustworkx as rx
///
/// graph = rx.generators.path_graph(5)
/// res = rx.descendants_at_distance(graph, 2, 2)
/// print(res)
///
/// will return: ``[0, 4]``
#[pyfunction]
pub fn graph_descendants_at_distance(
graph: graph::PyGraph,
source: usize,
distance: usize,
) -> iterators::NodeIndices {
let source = NodeIndex::new(source);
iterators::NodeIndices {
nodes: descendants_at_distance(&graph.graph, source, distance)
.into_iter()
.map(|x| x.index())
.collect(),
}
}

/// Returns all nodes at a fixed distance from ``source`` in ``graph``
///
/// :param PyDiGraph graph: The graph to find the descendants in
/// :param int source: The node index to find the descendants from
/// :param int distance: The distance from ``source``
///
/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``.
/// :rtype: NodeIndices
/// For example::
///
/// import rustworkx as rx
///
/// graph = rx.generators.directed_path_graph(5)
/// res = rx.descendants_at_distance(graph, 2, 2)
/// print(res)
///
/// will return: ``[4]``
#[pyfunction]
pub fn digraph_descendants_at_distance(
graph: digraph::PyDiGraph,
source: usize,
distance: usize,
) -> iterators::NodeIndices {
let source = NodeIndex::new(source);
iterators::NodeIndices {
nodes: descendants_at_distance(&graph.graph, source, distance)
.into_iter()
.map(|x| x.index())
.collect(),
}
}
33 changes: 33 additions & 0 deletions tests/rustworkx_tests/digraph/test_transitive_closure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest

import rustworkx as rx


class TestTransitivity(unittest.TestCase):
def test_path_graph(self):
graph = rx.generators.directed_path_graph(4)
transitive_closure = rx.transitive_closure_dag(graph)
expected_edge_list = [(0, 1), (1, 2), (2, 3), (1, 3), (0, 3), (0, 2)]
self.assertEqual(transitive_closure.edge_list(), expected_edge_list)

def test_invalid_type(self):
with self.assertRaises(TypeError):
rx.transitive_closure_dag(rx.PyGraph())

def test_cycle_error(self):
graph = rx.PyDiGraph()
graph.extend_from_edge_list([(0, 1), (1, 0)])
with self.assertRaises(rx.DAGHasCycle):
rx.transitive_closure_dag(graph)