Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new subsitute_subgraph() method to graph classes #823

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added new methods, :meth:`.PyDiGraph.subsitute_subgraph` and
:meth:`.PyGraph.substitute_subgraph`, which is used to replace
a subgraph in a graph object with an external graph.
122 changes: 121 additions & 1 deletion src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::io::{BufReader, BufWriter};
use std::str;

use hashbrown::{HashMap, HashSet};
use indexmap::IndexSet;
use indexmap::{IndexMap, IndexSet};

use rustworkx_core::dictmap::*;

Expand Down Expand Up @@ -2526,6 +2526,126 @@ impl PyDiGraph {
}
}

/// Substitute a subgraph in the graph with a different subgraph
///
/// This is used to replace a subgraph in this graph with another graph. A similar result
/// can be achieved by combining :meth:`~.PyDiGraph.contract_nodes` and
/// :meth:`~.PyDiGraph.substitute_node_with_subgraph`.
///
/// :param list nodes: A list of nodes in this graph representing the subgraph
/// to be removed.
/// :param PyDiGraph subgraph: The subgraph to replace ``nodes`` with
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
/// :param dict input_node_map: The mapping of node indices from ```nodes`` to a node
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
/// in ``subgraph``. This is used for incoming and outgoing edges into the removed
/// subgraph. This will replace any edges conneted to a node in ``nodes`` with the
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
/// other endpoint outside ``nodes`` where the node in ``nodes`` replaced via this
/// mapping.
/// :param callable edge_weight_map: An optional callable object that when
/// used will receive an edge's weight/data payload from ``subgraph`` and
/// will return an object to use as the weight for a newly created edge
/// after the edge is mapped from ``other``. If not specified the weight
/// from the edge in ``other`` will be copied by reference and used.
///
/// :returns: A mapping of node indices in ``other`` to the new node index in this graph
/// :rtype: NodeMap
pub fn substitute_subgraph(
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
&mut self,
py: Python,
nodes: Vec<usize>,
other: &PyDiGraph,
input_node_map: HashMap<usize, usize>,
edge_weight_map: Option<PyObject>,
) -> PyResult<NodeMap> {
let mut io_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new();
let mut node_map: IndexMap<usize, usize, ahash::RandomState> =
IndexMap::with_capacity_and_hasher(
other.graph.node_count(),
ahash::RandomState::default(),
);
let removed_nodes: HashSet<NodeIndex> = nodes.iter().map(|n| NodeIndex::new(*n)).collect();

let weight_map_fn = |obj: &PyObject, weight_fn: &Option<PyObject>| -> PyResult<PyObject> {
match weight_fn {
Some(weight_fn) => weight_fn.call1(py, (obj,)),
None => Ok(obj.clone_ref(py)),
}
};
for node in nodes {
let index = NodeIndex::new(node);
io_nodes.extend(
self.graph
.edges_directed(index, petgraph::Direction::Incoming)
.filter_map(|edge| {
if !removed_nodes.contains(&edge.source()) {
Some((edge.source(), edge.target(), edge.weight().clone_ref(py)))
} else {
None
}
}),
);
io_nodes.extend(
self.graph
.edges_directed(index, petgraph::Direction::Outgoing)
.filter_map(|edge| {
if !removed_nodes.contains(&edge.target()) {
Some((edge.source(), edge.target(), edge.weight().clone_ref(py)))
} else {
None
}
}),
);
self.graph.remove_node(index);
}
for node in other.graph.node_indices() {
let weight = other.graph.node_weight(node).unwrap();
let new_index = self.graph.add_node(weight.clone_ref(py));
node_map.insert(node.index(), new_index.index());
}
for edge in other.graph.edge_references() {
let new_source = node_map[edge.source().index()];
let new_target = node_map[edge.target().index()];
self.graph.add_edge(
NodeIndex::new(new_source),
NodeIndex::new(new_target),
weight_map_fn(edge.weight(), &edge_weight_map)?,
);
}
for edge in io_nodes {
let old_source = edge.0;
let new_source = if removed_nodes.contains(&old_source) {
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
match input_node_map.get(&old_source.index()) {
Some(new_source) => NodeIndex::new(node_map[new_source]),
None => {
let missing_index = old_source.index();
return Err(PyIndexError::new_err(format!(
"Input/Output node {} not found in io_node_map",
missing_index
)));
}
}
} else {
old_source
};
let old_target = edge.1;
let new_target = if removed_nodes.contains(&old_target) {
match input_node_map.get(&old_target.index()) {
Some(new_target) => NodeIndex::new(node_map[new_target]),
None => {
let missing_index = old_target.index();
return Err(PyIndexError::new_err(format!(
"Input/Output node {} not found in io_node_map",
missing_index
)));
}
}
} else {
old_target
};
self.graph.add_edge(new_source, new_target, edge.2);
}
Ok(NodeMap { node_map })
}

/// Return a new PyDiGraph object for an edge induced subgraph of this graph
///
/// The induced subgraph contains each edge in `edge_list` and each node
Expand Down
111 changes: 109 additions & 2 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::io::{BufReader, BufWriter};
use std::str;

use hashbrown::{HashMap, HashSet};
use indexmap::IndexSet;
use indexmap::{IndexMap, IndexSet};
use rustworkx_core::dictmap::*;

use pyo3::exceptions::PyIndexError;
Expand All @@ -36,7 +36,9 @@ use numpy::Complex64;
use numpy::PyReadonlyArray2;

use super::dot_utils::build_dot;
use super::iterators::{EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList};
use super::iterators::{
EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, NodeMap, WeightedEdgeList,
};
use super::{
find_node_by_weight, merge_duplicates, weight_callable, IsNan, NoEdgeBetweenNodes,
NodesRemoved, StablePyGraph,
Expand Down Expand Up @@ -1712,6 +1714,111 @@ impl PyGraph {
out_graph
}

/// Substitute a subgraph in the graph with a different subgraph
///
/// :param list nodes: A list of nodes in this graph representing the subgraph
/// to be removed.
/// :param PyDiGraph subgraph: The subgraph to replace ``nodes`` with
/// :param dict input_node_map: The mapping of node indices from ```nodes`` to a node
/// in ``subgraph``. This is used for incoming and outgoing edges into the removed
/// subgraph. This will replace any edges conneted to a node in ``nodes`` with the
/// other endpoint outside ``nodes`` where the node in ``nodes`` replaced via this
/// mapping.
/// :param callable edge_weight_map: An optional callable object that when
/// used will receive an edge's weight/data payload from ``subgraph`` and
/// will return an object to use as the weight for a newly created edge
/// after the edge is mapped from ``other``. If not specified the weight
/// from the edge in ``other`` will be copied by reference and used.
///
/// :returns: A mapping of node indices in ``other`` to the new node index in this graph
/// :rtype: NodeMap
pub fn substitute_subgraph(
&mut self,
py: Python,
nodes: Vec<usize>,
other: &PyGraph,
input_node_map: HashMap<usize, usize>,
edge_weight_map: Option<PyObject>,
) -> PyResult<NodeMap> {
let mut io_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new();
let mut node_map: IndexMap<usize, usize, ahash::RandomState> =
IndexMap::with_capacity_and_hasher(
other.graph.node_count(),
ahash::RandomState::default(),
);
let removed_nodes: HashSet<NodeIndex> = nodes.iter().map(|n| NodeIndex::new(*n)).collect();

let weight_map_fn = |obj: &PyObject, weight_fn: &Option<PyObject>| -> PyResult<PyObject> {
match weight_fn {
Some(weight_fn) => weight_fn.call1(py, (obj,)),
None => Ok(obj.clone_ref(py)),
}
};
for node in nodes {
let index = NodeIndex::new(node);
io_nodes.extend(
self.graph
.edges_directed(index, petgraph::Direction::Outgoing)
.filter_map(|edge| {
if !removed_nodes.contains(&edge.target()) {
Some((edge.source(), edge.target(), edge.weight().clone_ref(py)))
} else {
None
}
}),
);
self.graph.remove_node(index);
}
for node in other.graph.node_indices() {
let weight = other.graph.node_weight(node).unwrap();
let new_index = self.graph.add_node(weight.clone_ref(py));
node_map.insert(node.index(), new_index.index());
}
for edge in other.graph.edge_references() {
let new_source = node_map[edge.source().index()];
let new_target = node_map[edge.target().index()];
self.graph.add_edge(
NodeIndex::new(new_source),
NodeIndex::new(new_target),
weight_map_fn(edge.weight(), &edge_weight_map)?,
);
}
for edge in io_nodes {
let old_source = edge.0;
let new_source = if removed_nodes.contains(&old_source) {
match input_node_map.get(&old_source.index()) {
Some(new_source) => NodeIndex::new(node_map[new_source]),
None => {
let missing_index = old_source.index();
return Err(PyIndexError::new_err(format!(
"Input/Output node {} not found in io_node_map",
missing_index
)));
}
}
} else {
old_source
};
let old_target = edge.1;
let new_target = if removed_nodes.contains(&old_target) {
match input_node_map.get(&old_target.index()) {
Some(new_target) => NodeIndex::new(node_map[new_target]),
None => {
let missing_index = old_target.index();
return Err(PyIndexError::new_err(format!(
"Input/Output node {} not found in io_node_map",
missing_index
)));
}
}
} else {
old_target
};
self.graph.add_edge(new_source, new_target, edge.2);
}
Ok(NodeMap { node_map })
}

/// Return a shallow copy of the graph
///
/// All node and edge weight/data payloads in the copy will have a
Expand Down
56 changes: 56 additions & 0 deletions tests/rustworkx_tests/digraph/test_substitute_subgraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest
import rustworkx


class TestSubstitute(unittest.TestCase):
def setUp(self):
super().setUp()
self.graph = rustworkx.generators.directed_path_graph(5)

def test_empty_replacement(self):
in_graph = rustworkx.PyDiGraph()
with self.assertRaises(IndexError):
self.graph.substitute_subgraph([2], in_graph, {})

def test_single_node(self):
in_graph = rustworkx.PyDiGraph()
in_graph.add_node(0)
in_graph.add_child(0, 1, "edge")
res = self.graph.substitute_subgraph([2], in_graph, {2: 0})
self.assertEqual([(0, 1), (2, 5), (1, 2), (3, 4), (2, 3)], self.graph.edge_list())
self.assertEqual("edge", self.graph.get_edge_data(2, 5))
self.assertEqual(res, {0: 2, 1: 5})

def test_edge_weight_modifier(self):
in_graph = rustworkx.PyDiGraph()
in_graph.add_node(0)
in_graph.add_child(0, 1, "edge")
res = self.graph.substitute_subgraph(
[2],
in_graph,
{2: 0},
edge_weight_map=lambda edge: edge + "-migrated",
)
self.assertEqual([(0, 1), (2, 5), (1, 2), (3, 4), (2, 3)], self.graph.edge_list())
self.assertEqual("edge-migrated", self.graph.get_edge_data(2, 5))
self.assertEqual(res, {0: 2, 1: 5})

def test_multiple_mapping(self):
graph = rustworkx.generators.directed_star_graph(5)
in_graph = rustworkx.generators.directed_star_graph(3, inward=True)
res = graph.substitute_subgraph([0, 1, 2], in_graph, {0: 0, 1: 1, 2: 2})
self.assertEqual({0: 2, 1: 1, 2: 0}, res)
expected = [(1, 2), (0, 2), (2, 4), (2, 3)]
self.assertEqual(expected, graph.edge_list())
58 changes: 58 additions & 0 deletions tests/rustworkx_tests/graph/test_substitute_subgraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest
import rustworkx


class TestSubstitute(unittest.TestCase):
def setUp(self):
super().setUp()
self.graph = rustworkx.generators.path_graph(5)

def test_empty_replacement(self):
in_graph = rustworkx.PyGraph()
with self.assertRaises(IndexError):
self.graph.substitute_subgraph([2], in_graph, {})

def test_single_node(self):
in_graph = rustworkx.PyGraph()
in_graph.add_node(0)
in_graph.add_node(1)
in_graph.add_edge(0, 1, "edge")
res = self.graph.substitute_subgraph([2], in_graph, {2: 0})
self.assertEqual([(0, 1), (2, 5), (2, 3), (3, 4), (2, 1)], self.graph.edge_list())
self.assertEqual("edge", self.graph.get_edge_data(2, 5))
self.assertEqual(res, {0: 2, 1: 5})

def test_edge_weight_modifier(self):
in_graph = rustworkx.PyGraph()
in_graph.add_node(0)
in_graph.add_node(1)
in_graph.add_edge(0, 1, "edge")
res = self.graph.substitute_subgraph(
[2],
in_graph,
{2: 0},
edge_weight_map=lambda edge: edge + "-migrated",
)
self.assertEqual([(0, 1), (2, 5), (2, 3), (3, 4), (2, 1)], self.graph.edge_list())
self.assertEqual("edge-migrated", self.graph.get_edge_data(2, 5))
self.assertEqual(res, {0: 2, 1: 5})

def test_multiple_mapping(self):
graph = rustworkx.generators.star_graph(5)
in_graph = rustworkx.generators.path_graph(3)
res = graph.substitute_subgraph([0, 1, 2], in_graph, {0: 0, 1: 1, 2: 2})
self.assertEqual({0: 2, 1: 1, 2: 0}, res)
expected = [(2, 1), (1, 0), (2, 4), (2, 3)]
self.assertEqual(expected, graph.edge_list())