Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Graph difference #571

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ the functions from the explicitly typed based on the data type.
retworkx.digraph_core_number
retworkx.digraph_complement
retworkx.digraph_union
retworkx.digraph_difference
retworkx.digraph_tensor_product
retworkx.digraph_cartesian_product
retworkx.digraph_random_layout
Expand Down Expand Up @@ -326,6 +327,7 @@ typed API based on the data type.
retworkx.graph_core_number
retworkx.graph_complement
retworkx.graph_union
retworkx.graph_difference
retworkx.graph_tensor_product
retworkx.graph_cartesian_product
retworkx.graph_random_layout
Expand Down
31 changes: 31 additions & 0 deletions releasenotes/notes/add-graph-difference-9916bf3d612f0b1a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
---
features:
- |
Added two new functions that calculate the difference of two graphs: :func:`~retworkx.graph_difference`
for undirected graphs and :func:`~retworkx.digraph_difference` for directed graphs. For example:

.. jupyter-execute::

import retworkx
from retworkx.visualization import mpl_draw

graph_1 = retworkx.PyGraph()
graph_1.add_nodes_from(["a_1", "a_2", "a_3", "a_4"])
graph_1.extend_from_weighted_edge_list([(0, 1, "e_1"),
(1, 2, "e_2"),
(2, 3, "e_3"),
(3, 0, "e_4"),
(0, 2, "e_5"),
(1, 3, "e_6"),
])
graph_2 = retworkx.PyGraph()
graph_2.add_nodes_from(["a_1", "a_2", "a_3", "a_4"])
graph_2.extend_from_weighted_edge_list([(0, 1, "e_1"),
(1, 2, "e_2"),
(2, 3, "e_3"),
(3, 0, "e_4"),
])

graph_difference = retworkx.graph_difference(graph_1, graph_2)

mpl_draw(graph_difference)
36 changes: 36 additions & 0 deletions retworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,6 +1814,42 @@ def _graph_cartesian_product(
return graph_cartesian_product(first, second)


@functools.singledispatch
def difference(
    first,
    second,
):
    """Return a new graph that is the difference of two input graph objects

    The result contains every node of ``first`` and only those edges of
    ``first`` that have no counterpart between the same nodes in ``second``.
    Both inputs must be of the same type and must contain exactly the same
    set of node indices.

    :param first: The first graph object
    :param second: The second graph object

    :returns: A new graph object that is the difference of ``first`` and
        ``second``. It's worth noting the weight/data payload objects are
        passed by reference from ``first`` to this new object.

    :rtype: :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph`

    :raises TypeError: If ``first`` is neither a :class:`~retworkx.PyGraph`
        nor a :class:`~retworkx.PyDiGraph`
    :raises IndexError: If the node indices of ``first`` and ``second``
        are not the same
    """
    # Fallback for unregistered types; the PyGraph/PyDiGraph overloads are
    # attached with ``difference.register``.
    raise TypeError("Invalid Input Type %s for graph" % type(first))


@difference.register(PyDiGraph)
def _digraph_difference(first, second):
    # PyDiGraph overload of ``difference``: forward both graphs unchanged to
    # the directed implementation.
    return digraph_difference(first, second)


@difference.register(PyGraph)
def _graph_difference(first, second):
    # PyGraph overload of ``difference``: forward both graphs unchanged to
    # the undirected implementation.
    return graph_difference(first, second)


@functools.singledispatch
def bfs_search(graph, source, visitor):
"""Breadth-first traversal of a directed/undirected graph.
Expand Down
114 changes: 114 additions & 0 deletions src/difference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use crate::{digraph, graph, StablePyGraph};

use hashbrown::HashSet;

use petgraph::visit::{EdgeRef, IntoEdgeReferences};
use petgraph::{algo, EdgeType};

use pyo3::exceptions::PyIndexError;
use pyo3::prelude::*;
use pyo3::Python;

/// Build the difference of two graphs that share the same set of node indices.
///
/// The returned graph holds one node for every node of `first` and every edge
/// of `first` for which `second` has no edge between the same endpoints.
///
/// Node and edge payloads are always taken (by reference) from `first`. The
/// payloads stored in `second` are never inspected, so if the node weights of
/// the two graphs differ only the ones from `first` are preserved; there is
/// no callback for reconciling them.
///
/// # Errors
///
/// Returns a `PyIndexError` when the node index sets of `first` and `second`
/// are not equal.
fn difference<Ty: EdgeType>(
    py: Python,
    first: &StablePyGraph<Ty>,
    second: &StablePyGraph<Ty>,
) -> PyResult<StablePyGraph<Ty>> {
    let indexes_first = first.node_indices().collect::<HashSet<_>>();
    let indexes_second = second.node_indices().collect::<HashSet<_>>();

    if indexes_first != indexes_second {
        return Err(PyIndexError::new_err(
            "Node sets of the graphs should be equal",
        ));
    }

    // The edge capacity is only a hint. `second` may hold more edges than
    // `first` (e.g. parallel edges), so a plain subtraction could underflow.
    let mut final_graph = StablePyGraph::<Ty>::with_capacity(
        first.node_count(),
        first.edge_count().saturating_sub(second.edge_count()),
    );

    // Nodes are re-added with fresh, contiguous indices. Remember where each
    // original index ended up so that edges stay attached to the right nodes
    // even when `first` has holes left behind by removed nodes.
    let node_map: std::collections::HashMap<_, _> = first
        .node_indices()
        .map(|node| (node, final_graph.add_node(first[node].clone_ref(py))))
        .collect();

    for edge in first.edge_references() {
        // Keep only the edges of `first` that have no counterpart in `second`.
        if second.find_edge(edge.source(), edge.target()).is_none() {
            final_graph.add_edge(
                node_map[&edge.source()],
                node_map[&edge.target()],
                edge.weight().clone_ref(py),
            );
        }
    }

    Ok(final_graph)
}

/// Return a new PyGraph that is the difference from two input
/// PyGraph objects
///
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the big thing missing here is the constraints on the input types (they have to have identical node indices). This probably should mention that and also maybe have an example and/or an explanation of how the difference is computed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! I'll add examples and explain how we compute the difference in the docs.

/// :param PyGraph first: The first undirected graph object
/// :param PyGraph second: The second undirected graph object
///
/// :returns: A new PyGraph object that is the difference of ``first``
/// and ``second``. It's worth noting the weight/data payload objects are
/// passed by reference from ``first`` graph to this new object.
///
/// :rtype: :class:`~retworkx.PyGraph`
#[pyfunction()]
#[pyo3(text_signature = "(first, second, /)")]
fn graph_difference(
py: Python,
first: &graph::PyGraph,
second: &graph::PyGraph,
) -> PyResult<graph::PyGraph> {
let out_graph = difference(py, &first.graph, &second.graph)?;

Ok(graph::PyGraph {
graph: out_graph,
multigraph: true,
node_removed: false,
})
}

/// Return a new PyDiGraph that is the difference of two input
/// PyDiGraph objects
///
/// The difference is computed edge-wise: the returned graph contains every
/// node of ``first`` and only those edges of ``first`` for which ``second``
/// has no edge from the same source node to the same target node. Both
/// graphs must contain exactly the same set of node indices. The returned
/// graph is always created as a multigraph with cycle checking disabled.
///
/// :param PyDiGraph first: The first directed graph object
/// :param PyDiGraph second: The second directed graph object
///
/// :returns: A new PyDiGraph object that is the difference of ``first``
///     and ``second``. It's worth noting the weight/data payload objects are
///     passed by reference from ``first`` graph to this new object.
///
/// :rtype: :class:`~retworkx.PyDiGraph`
///
/// :raises IndexError: If the node indices of ``first`` and ``second``
///     are not the same
#[pyfunction()]
#[pyo3(text_signature = "(first, second, /)")]
fn digraph_difference(
    py: Python,
    first: &digraph::PyDiGraph,
    second: &digraph::PyDiGraph,
) -> PyResult<digraph::PyDiGraph> {
    let out_graph = difference(py, &first.graph, &second.graph)?;

    Ok(digraph::PyDiGraph {
        graph: out_graph,
        cycle_state: algo::DfsSpace::default(),
        check_cycle: false,
        node_removed: false,
        multigraph: true,
    })
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mod centrality;
mod coloring;
mod connectivity;
mod dag_algo;
mod difference;
mod digraph;
mod dot_utils;
mod generators;
Expand All @@ -38,6 +39,7 @@ use centrality::*;
use coloring::*;
use connectivity::*;
use dag_algo::*;
use difference::*;
use isomorphism::*;
use layout::*;
use matching::*;
Expand Down Expand Up @@ -325,6 +327,8 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(graph_union))?;
m.add_wrapped(wrap_pyfunction!(digraph_cartesian_product))?;
m.add_wrapped(wrap_pyfunction!(graph_cartesian_product))?;
m.add_wrapped(wrap_pyfunction!(digraph_difference))?;
m.add_wrapped(wrap_pyfunction!(graph_difference))?;
m.add_wrapped(wrap_pyfunction!(topological_sort))?;
m.add_wrapped(wrap_pyfunction!(descendants))?;
m.add_wrapped(wrap_pyfunction!(ancestors))?;
Expand Down
63 changes: 63 additions & 0 deletions tests/digraph/test_difference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest
import retworkx


class TestDifference(unittest.TestCase):
    # Exercises retworkx.digraph_difference on directed graphs.

    def test_null_difference_null(self):
        # Two empty graphs yield an empty difference.
        empty_first = retworkx.PyDiGraph()
        empty_second = retworkx.PyDiGraph()

        result = retworkx.digraph_difference(empty_first, empty_second)

        self.assertEqual(result.num_nodes(), 0)
        self.assertEqual(result.num_edges(), 0)

    def test_difference_non_matching(self):
        # Graphs whose node index sets differ must be rejected.
        two_node_path = retworkx.generators.directed_path_graph(2)
        three_node_path = retworkx.generators.directed_path_graph(3)

        with self.assertRaises(IndexError):
            retworkx.digraph_difference(two_node_path, three_node_path)

    def test_difference_weights_edges(self):
        # Only the edges missing from the second graph survive, and they keep
        # the payloads they had in the first graph.
        node_payloads = ["a_1", "a_2", "a_3", "a_4"]
        shared_edges = [
            (0, 1, "e_1"),
            (1, 2, "e_2"),
            (2, 3, "e_3"),
            (3, 0, "e_4"),
        ]
        extra_edges = [(0, 2, "e_5"), (1, 3, "e_6")]

        minuend = retworkx.PyDiGraph()
        minuend.add_nodes_from(node_payloads)
        minuend.extend_from_weighted_edge_list(shared_edges + extra_edges)

        subtrahend = retworkx.PyDiGraph()
        subtrahend.add_nodes_from(node_payloads)
        subtrahend.extend_from_weighted_edge_list(shared_edges)

        result = retworkx.digraph_difference(minuend, subtrahend)

        self.assertEqual(result.num_nodes(), 4)
        self.assertEqual(result.num_edges(), 2)
        self.assertEqual(result.weighted_edge_list(), extra_edges)
63 changes: 63 additions & 0 deletions tests/graph/test_difference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest
import retworkx


class TestDifference(unittest.TestCase):
    # Exercises retworkx.graph_difference on undirected graphs.

    def test_null_difference_null(self):
        # Two empty graphs yield an empty difference.
        empty_first = retworkx.PyGraph()
        empty_second = retworkx.PyGraph()

        result = retworkx.graph_difference(empty_first, empty_second)

        self.assertEqual(result.num_nodes(), 0)
        self.assertEqual(result.num_edges(), 0)

    def test_difference_non_matching(self):
        # Graphs whose node index sets differ must be rejected.
        two_node_path = retworkx.generators.path_graph(2)
        three_node_path = retworkx.generators.path_graph(3)

        with self.assertRaises(IndexError):
            retworkx.graph_difference(two_node_path, three_node_path)

    def test_difference_weights(self):
        # Only the edges missing from the second graph survive, and they keep
        # the payloads they had in the first graph.
        node_payloads = ["a_1", "a_2", "a_3", "a_4"]
        shared_edges = [
            (0, 1, "e_1"),
            (1, 2, "e_2"),
            (2, 3, "e_3"),
            (3, 0, "e_4"),
        ]
        extra_edges = [(0, 2, "e_5"), (1, 3, "e_6")]

        minuend = retworkx.PyGraph()
        minuend.add_nodes_from(node_payloads)
        minuend.extend_from_weighted_edge_list(shared_edges + extra_edges)

        subtrahend = retworkx.PyGraph()
        subtrahend.add_nodes_from(node_payloads)
        subtrahend.extend_from_weighted_edge_list(shared_edges)

        result = retworkx.graph_difference(minuend, subtrahend)

        self.assertEqual(result.num_nodes(), 4)
        self.assertEqual(result.num_edges(), 2)
        self.assertEqual(result.weighted_edge_list(), extra_edges)