diff --git a/src/dag_algo/lexicographical_sort.rs b/src/dag_algo/lexicographical_sort.rs
new file mode 100644
index 000000000..54e13c5df
--- /dev/null
+++ b/src/dag_algo/lexicographical_sort.rs
@@ -0,0 +1,97 @@
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+use std::cmp::Ordering;
+use std::collections::BinaryHeap;
+
+use hashbrown::HashMap;
+
+use pyo3::prelude::*;
+use pyo3::Python;
+
+use petgraph::graph::NodeIndex;
+use petgraph::visit::NodeCount;
+
+use crate::digraph;
+
+pub fn lexicographical_topological_sort(
+    py: Python,
+    dag: &digraph::PyDiGraph,
+    key: PyObject,
+) -> PyResult<Vec<NodeIndex>> {
+    let key_callable = |a: &PyObject| -> PyResult<PyObject> {
+        let res = key.call1(py, (a,))?;
+        Ok(res.to_object(py))
+    };
+    // HashMap of node_index indegree
+    let node_count = dag.node_count();
+    let mut in_degree_map: HashMap<NodeIndex, usize> = HashMap::with_capacity(node_count);
+    for node in dag.graph.node_indices() {
+        in_degree_map.insert(node, dag.in_degree(node.index()));
+    }
+
+    #[derive(Clone, Eq, PartialEq)]
+    struct State {
+        key: String,
+        node: NodeIndex,
+    }
+
+    impl Ord for State {
+        fn cmp(&self, other: &State) -> Ordering {
+            // Notice that we flip the ordering on keys.
+            // In case of a tie we compare positions - this step is necessary
+            // to make implementations of `PartialEq` and `Ord` consistent.
+            other
+                .key
+                .cmp(&self.key)
+                .then_with(|| other.node.index().cmp(&self.node.index()))
+        }
+    }
+
+    // `PartialOrd` needs to be implemented as well.
+    impl PartialOrd for State {
+        fn partial_cmp(&self, other: &State) -> Option<Ordering> {
+            Some(self.cmp(other))
+        }
+    }
+    let mut zero_indegree = BinaryHeap::with_capacity(node_count);
+    for (node, degree) in in_degree_map.iter() {
+        if *degree == 0 {
+            let map_key_raw = key_callable(&dag.graph[*node])?;
+            let map_key: String = map_key_raw.extract(py)?;
+            zero_indegree.push(State {
+                key: map_key,
+                node: *node,
+            });
+        }
+    }
+    let mut out_list: Vec<NodeIndex> = Vec::with_capacity(node_count);
+    let dir = petgraph::Direction::Outgoing;
+    while let Some(State { node, .. }) = zero_indegree.pop() {
+        let neighbors = dag.graph.neighbors_directed(node, dir);
+        for child in neighbors {
+            let child_degree = in_degree_map.get_mut(&child).unwrap();
+            *child_degree -= 1;
+            if *child_degree == 0 {
+                let map_key_raw = key_callable(&dag.graph[child])?;
+                let map_key: String = map_key_raw.extract(py)?;
+                zero_indegree.push(State {
+                    key: map_key,
+                    node: child,
+                });
+                in_degree_map.remove(&child);
+            }
+        }
+        out_list.push(node)
+    }
+    Ok(out_list)
+}
diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs
index 7a21d5b92..a68c0afdf 100644
--- a/src/dag_algo/mod.rs
+++ b/src/dag_algo/mod.rs
@@ -10,11 +10,11 @@
 // License for the specific language governing permissions and limitations
 // under the License.
 
+mod lexicographical_sort;
 mod longest_path;
+mod multiblock;
 
 use hashbrown::{HashMap, HashSet};
-use std::cmp::Ordering;
-use std::collections::BinaryHeap;
 
 use super::iterators::NodeIndices;
 use crate::{digraph, DAGHasCycle, InvalidNode};
@@ -341,7 +341,7 @@ pub fn layers(
 /// order used.
 ///
 /// :param PyDiGraph dag: The DAG to get the topological sorted nodes from
-/// :param callable key: key is a python function or other callable that
+/// :param Callable key: key is a python function or other callable that
 ///     gets passed a single argument the node data from the graph and is
 ///     expected to return a string which will be used for resolving ties
 ///     in the sorting order.
@@ -355,75 +355,47 @@ pub fn lexicographical_topological_sort(
     dag: &digraph::PyDiGraph,
     key: PyObject,
 ) -> PyResult<PyObject> {
-    let key_callable = |a: &PyObject| -> PyResult<PyObject> {
-        let res = key.call1(py, (a,))?;
-        Ok(res.to_object(py))
-    };
-    // HashMap of node_index indegree
-    let node_count = dag.node_count();
-    let mut in_degree_map: HashMap<NodeIndex, usize> = HashMap::with_capacity(node_count);
-    for node in dag.graph.node_indices() {
-        in_degree_map.insert(node, dag.in_degree(node.index()));
-    }
-
-    #[derive(Clone, Eq, PartialEq)]
-    struct State {
-        key: String,
-        node: NodeIndex,
-    }
-
-    impl Ord for State {
-        fn cmp(&self, other: &State) -> Ordering {
-            // Notice that the we flip the ordering on costs.
-            // In case of a tie we compare positions - this step is necessary
-            // to make implementations of `PartialEq` and `Ord` consistent.
-            other
-                .key
-                .cmp(&self.key)
-                .then_with(|| other.node.index().cmp(&self.node.index()))
-        }
-    }
-
-    // `PartialOrd` needs to be implemented as well.
-    impl PartialOrd for State {
-        fn partial_cmp(&self, other: &State) -> Option<Ordering> {
-            Some(self.cmp(other))
-        }
-    }
-    let mut zero_indegree = BinaryHeap::with_capacity(node_count);
-    for (node, degree) in in_degree_map.iter() {
-        if *degree == 0 {
-            let map_key_raw = key_callable(&dag.graph[*node])?;
-            let map_key: String = map_key_raw.extract(py)?;
-            zero_indegree.push(State {
-                key: map_key,
-                node: *node,
-            });
-        }
-    }
-    let mut out_list: Vec<&PyObject> = Vec::with_capacity(node_count);
-    let dir = petgraph::Direction::Outgoing;
-    while let Some(State { node, .. }) = zero_indegree.pop() {
-        let neighbors = dag.graph.neighbors_directed(node, dir);
-        for child in neighbors {
-            let child_degree = in_degree_map.get_mut(&child).unwrap();
-            *child_degree -= 1;
-            if *child_degree == 0 {
-                let map_key_raw = key_callable(&dag.graph[child])?;
-                let map_key: String = map_key_raw.extract(py)?;
-                zero_indegree.push(State {
-                    key: map_key,
-                    node: child,
-                });
-                in_degree_map.remove(&child);
-            }
-        }
-        out_list.push(&dag.graph[node])
-    }
+    let sort_list = lexicographical_sort::lexicographical_topological_sort(py, dag, key)?;
+    let out_list: Vec<&PyObject> = sort_list.iter().map(|node| &dag.graph[*node]).collect();
     Ok(PyList::new(py, out_list).into())
 }
 
-/// Return the topological sort of node indices from the provided graph
+/// Collect multi-blocks from the graph
+///
+/// Multi-blocks are uninterrupted sequences of nodes that operate on the same groups.
+///
+/// :param PyDiGraph dag: The graph to find the multiblocks in
+/// :param int block_size: The maximum block size to find blocks for
+/// :param Callable key: key is a python function or other callable that
+///     gets passed a single argument the node data from the graph and is
+///     expected to return a string which will be used for sorting. This is
+///     the same as the ``key`` parameter from
+///     :func:`~rustworkx.lexicographical_topological_sort`
+/// :param Callable group_fn: A callback function that will receive the node's
+///     data payload as the algorithm runs and it should return a ``set`` of
+///     groups that the node operates on
+/// :param Callable filter_fn: A filter function that will receive a node's
+///     data payload and is expected to either return ``None``, ``True``, or
+///     ``False``. If ``None`` the node will be skipped, if ``True`` the
+///     node will be processed by the algorithm, and if ``False`` the node
+///     will be considered unprocessable by the algorithm, but the algorithm
+///     will still update its internal state to account for it.
+/// :returns: a list of lists of node indices for each multiblock found in the
+///     graph
+/// :rtype: list
+#[pyfunction]
+pub fn collect_multi_blocks(
+    py: Python,
+    dag: &digraph::PyDiGraph,
+    block_size: usize,
+    key: PyObject,
+    group_fn: PyObject,
+    filter_fn: PyObject,
+) -> PyResult<Vec<Vec<usize>>> {
+    multiblock::collect_multi_blocks(py, dag, block_size, key, group_fn, filter_fn)
+}
+
+/// Return the topological sort of node indexes from the provided graph
 ///
 /// :param PyDiGraph graph: The DAG to get the topological sort on
 ///
diff --git a/src/dag_algo/multiblock.rs b/src/dag_algo/multiblock.rs
new file mode 100644
index 000000000..1e2bc872d
--- /dev/null
+++ b/src/dag_algo/multiblock.rs
@@ -0,0 +1,230 @@
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+use std::mem;
+
+use hashbrown::{HashMap, HashSet};
+use rayon::prelude::*;
+
+use pyo3::prelude::*;
+use pyo3::Python;
+
+use super::lexicographical_sort::lexicographical_topological_sort;
+use crate::digraph;
+
+/// DSU function for finding the root of a set of items.
+/// If my parent is myself, I am the root. Otherwise we recursively
+/// find the root for my parent. After that, we assign my parent to be
+/// my root, saving recursion in the future.
+fn find_set(
+    index: usize,
+    parent: &mut HashMap<usize, usize>,
+    groups: &mut HashMap<usize, Vec<usize>>,
+    op_groups: &mut HashMap<usize, Vec<usize>>,
+) -> usize {
+    let mut update_index: Vec<usize> = Vec::new();
+    let mut index_iter = index;
+    while parent.get(&index_iter) != Some(&index_iter) {
+        if parent.get(&index_iter).is_none() {
+            parent.insert(index_iter, index_iter);
+            groups.insert(index_iter, vec![index_iter]);
+            op_groups.insert(index_iter, Vec::new());
+        }
+        if parent.get(&index_iter) != Some(&index_iter) {
+            update_index.push(index_iter);
+        }
+        index_iter = parent[&index_iter];
+    }
+    for index in update_index {
+        parent.insert(index, index_iter);
+    }
+    parent[&index_iter]
+}
+
+fn combine_sets(op_groups: &mut HashMap<usize, Vec<usize>>, set_a: usize, set_b: usize) {
+    let mut other_groups = op_groups.get_mut(&set_b).unwrap().clone();
+    op_groups.get_mut(&set_a).unwrap().append(&mut other_groups);
+    op_groups.get_mut(&set_b).unwrap().clear();
+}
+
+/// DSU function for unioning two sets together.
+/// Find the roots of each set. Then assign one to have the other
+/// as its parent, thus linking the sets.
+/// Merges the smaller set into the larger set in order to have better runtime.
+fn union_set(
+    set_a_ind: usize,
+    set_b_ind: usize,
+    parent: &mut HashMap<usize, usize>,
+    groups: &mut HashMap<usize, Vec<usize>>,
+    op_groups: &mut HashMap<usize, Vec<usize>>,
+) {
+    let mut set_a = find_set(set_a_ind, parent, groups, op_groups);
+    let mut set_b = find_set(set_b_ind, parent, groups, op_groups);
+
+    if set_a == set_b {
+        return;
+    }
+    if op_groups[&set_a].len() < op_groups[&set_b].len() {
+        mem::swap(&mut set_a, &mut set_b);
+    }
+    parent.insert(set_b, set_a);
+    combine_sets(op_groups, set_a, set_b);
+    combine_sets(groups, set_a, set_b)
+}
+
+fn update_set(
+    group_index: usize,
+    parent: &mut HashMap<usize, usize>,
+    groups: &mut HashMap<usize, Vec<usize>>,
+    op_groups: &mut HashMap<usize, Vec<usize>>,
+    block_list: &mut Vec<Vec<usize>>,
+) {
+    if !op_groups[&group_index].is_empty() {
+        block_list.push(op_groups[&group_index].to_vec());
+    }
+    let cur_set: HashSet<usize> = groups[&group_index].iter().copied().collect();
+    for v in cur_set {
+        parent.insert(v, v);
+        groups.insert(v, vec![v]);
+        op_groups.insert(v, Vec::new());
+    }
+}
+
+pub fn collect_multi_blocks(
+    py: Python,
+    dag: &digraph::PyDiGraph,
+    block_size: usize,
+    key: PyObject,
+    group_fn: PyObject,
+    filter_fn: PyObject,
+) -> PyResult<Vec<Vec<usize>>> {
+    let mut block_list: Vec<Vec<usize>> = Vec::new();
+
+    let mut parent: HashMap<usize, usize> = HashMap::new();
+    let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
+    let mut op_groups: HashMap<usize, Vec<usize>> = HashMap::new();
+
+    // let sort_nodes = lexicographical_topological_sort(py, dag, key)?;
+    for node in lexicographical_topological_sort(py, dag, key)? {
+        let filter_res = filter_fn.call1(py, (&dag.graph[node],))?;
+        let can_process_option: Option<bool> = filter_res.extract(py)?;
+        if can_process_option.is_none() {
+            continue;
+        }
+        let raw_cur_nodes = group_fn.call1(py, (&dag.graph[node],))?;
+        let cur_groups: HashSet<usize> = raw_cur_nodes.extract(py)?;
+        let can_process: bool = can_process_option.unwrap();
+        let mut makes_too_big: bool = false;
+
+        if can_process {
+            let mut tops: HashSet<usize> = HashSet::new();
+            for group in &cur_groups {
+                tops.insert(find_set(*group, &mut parent, &mut groups, &mut op_groups));
+            }
+            let mut tot_size = 0;
+            for group in tops {
+                tot_size += groups[&group].len();
+            }
+            if tot_size > block_size {
+                makes_too_big = true;
+            }
+        }
+        if !can_process {
+            // resolve the case where we cannot process this node
+            for group_entry in &cur_groups {
+                let group = find_set(*group_entry, &mut parent, &mut groups, &mut op_groups);
+                if op_groups[&group].is_empty() {
+                    continue;
+                }
+                update_set(
+                    group,
+                    &mut parent,
+                    &mut groups,
+                    &mut op_groups,
+                    &mut block_list,
+                );
+            }
+        }
+        if makes_too_big {
+            // Adding in all of the new groups would make the block too big,
+            // so we must block off sub-portions of the groups until the new
+            // block would no longer be too big
+            let mut savings: HashMap<usize, usize> = HashMap::new();
+            let mut tot_size = 0;
+            for group in &cur_groups {
+                let top = find_set(*group, &mut parent, &mut groups, &mut op_groups);
+                if !savings.contains_key(&top) {
+                    savings.insert(top, groups[&top].len() - 1);
+                    tot_size += groups[&top].len();
+                } else {
+                    *savings.get_mut(&top).unwrap() -= 1;
+                }
+            }
+            let mut savings_list: Vec<(usize, usize)> = savings
+                .into_iter()
+                .map(|(item, value)| (value, item))
+                .collect();
+            savings_list.par_sort_unstable();
+            savings_list.reverse();
+            let mut savings_need = tot_size - block_size;
+            for item in savings_list {
+                // Remove groups until the resulting size would be acceptable,
+                // starting with blocking out the group that would decrease
+                // the new size the most. This heuristic for which blocks we
+                // create does not necessarily give the optimal blocking;
+                // other heuristics may be worth considering.
+                if savings_need > 0 {
+                    savings_need = savings_need.saturating_sub(item.0);
+                    let item_index = item.1;
+                    update_set(
+                        item_index,
+                        &mut parent,
+                        &mut groups,
+                        &mut op_groups,
+                        &mut block_list,
+                    );
+                }
+            }
+        }
+        if can_process {
+            // If the node can be processed, either skip it because it touches
+            // too many groups, or union all of the groups it operates on
+            if cur_groups.len() > block_size {
+                // nodes operating on more groups than block_size cannot be a
+                // part of any block and thus we skip them here.
+                // We have already finalized the blocks involving the node's
+                // groups in the makes_too_big branch above
+                continue; // unable to be part of a group
+            }
+            let mut prev: Option<usize> = None;
+            for group in cur_groups {
+                let index = group;
+                if let Some(value) = prev {
+                    union_set(value, index, &mut parent, &mut groups, &mut op_groups);
+                }
+                prev = Some(index);
+            }
+            if let Some(value) = prev {
+                let found_set = find_set(value, &mut parent, &mut groups, &mut op_groups);
+                op_groups.get_mut(&found_set).unwrap().push(node.index());
+            }
+        }
+    }
+
+    for (index, parent) in parent.iter() {
+        let parent_index = parent;
+        if parent_index == index && !op_groups[parent].is_empty() {
+            block_list.push(op_groups[index].to_vec());
+        }
+    }
+    Ok(block_list)
+}
diff --git a/src/lib.rs b/src/lib.rs
index 219e2f952..8ec16b985 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -373,6 +373,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(descendants))?;
     m.add_wrapped(wrap_pyfunction!(ancestors))?;
     m.add_wrapped(wrap_pyfunction!(lexicographical_topological_sort))?;
+    m.add_wrapped(wrap_pyfunction!(collect_multi_blocks))?;
     m.add_wrapped(wrap_pyfunction!(graph_floyd_warshall))?;
     m.add_wrapped(wrap_pyfunction!(digraph_floyd_warshall))?;
     m.add_wrapped(wrap_pyfunction!(graph_floyd_warshall_numpy))?;
diff --git a/tests/rustworkx_tests/digraph/test_collect_multi_block.py b/tests/rustworkx_tests/digraph/test_collect_multi_block.py
new file mode 100644
index 000000000..bd1bf54f1
--- /dev/null
+++ b/tests/rustworkx_tests/digraph/test_collect_multi_block.py
@@ -0,0 +1,186 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import unittest
+import rustworkx
+
+
+class TestCollectMultiBlock(unittest.TestCase):
+    def test_blocks_2q_blocks(self):
+        graph = rustworkx.PyDiGraph()
+        q0 = graph.add_node({"type": "in", "name": "q0", "groups": []})
+        q1 = graph.add_node({"type": "in", "name": "q1", "groups": []})
+        q2 = graph.add_node({"type": "in", "name": "q2", "groups": []})
+        u1 = graph.add_node({"type": "op", "name": "u1", "groups": [0]})
+        u2 = graph.add_node({"type": "op", "name": "u2", "groups": [1]})
+        cx_1 = graph.add_node({"type": "op", "name": "cx", "groups": [2, 1]})
+        cx_2 = graph.add_node({"type": "op", "name": "cx", "groups": [0, 1]})
+        q0_out = graph.add_node({"type": "out", "name": "q0", "groups": []})
+        q1_out = graph.add_node({"type": "out", "name": "q1", "groups": []})
+        q2_out = graph.add_node({"type": "out", "name": "q2", "groups": []})
+        graph.add_edges_from_no_data(
+            [
+                (q0, u1),
+                (q1, u2),
+                (q2, cx_1),
+                (u2, cx_1),
+                (u1, cx_2),
+                (cx_1, cx_2),
+                (cx_2, q2_out),
+                (cx_1, q1_out),
+                (cx_2, q0_out),
+            ]
+        )
+
+        def group_fn(node):
+            return set(node["groups"])
+
+        def key_fn(node):
+            if node["type"] == "in":
+                return "a"
+            if node["type"] != "op":
+                return "d"
+            return "b" + chr(ord("a") + len(node["groups"]))
+
+        def filter_fn(node):
+            if node["type"] != "op":
+                return None
+            return True
+
+        blocks = rustworkx.collect_multi_blocks(graph, 2, key_fn, group_fn, filter_fn)
+        self.assertEqual(blocks, [[4, 5], [3, 6]])
+
+    def test_blocks_unprocessed(self):
+        graph = rustworkx.PyDiGraph()
+        q0 = graph.add_node({"type": "in", "name": "q0", "groups": []})
+        q1 = graph.add_node({"type": "in", "name": "q1", "groups": []})
+        q2 = graph.add_node({"type": "in", "name": "q2", "groups": []})
+        c0 = graph.add_node({"type": "in", "name": "c0", "groups": []})
+        cx_1 = graph.add_node({"type": "op", "name": "cx", "groups": [0, 1]})
+        cx_2 = graph.add_node({"type": "op", "name": "cx", "groups": [1, 2]})
+        measure = graph.add_node({"type": "op", "name": "measure", "groups": [0]})
+        cx_3 = graph.add_node({"type": "op", "name": "cx", "groups": [1, 2]})
+        x = graph.add_node({"type": "op", "name": "x", "groups": [1]})
+        h = graph.add_node({"type": "op", "name": "h", "groups": [2]})
+        q0_out = graph.add_node({"type": "out", "name": "q0", "groups": []})
+        q1_out = graph.add_node({"type": "out", "name": "q1", "groups": []})
+        q2_out = graph.add_node({"type": "out", "name": "q2", "groups": []})
+        c0_out = graph.add_node({"type": "out", "name": "c0", "groups": []})
+        graph.add_edges_from_no_data(
+            [
+                (q0, cx_1),
+                (q1, cx_1),
+                (cx_1, cx_2),
+                (q2, cx_2),
+                (cx_1, measure),
+                (c0, measure),
+                (cx_2, cx_3),
+                (cx_2, cx_3),
+                (cx_3, x),
+                (cx_3, h),
+                (measure, q0_out),
+                (measure, c0_out),
+                (x, q1_out),
+                (h, q2_out),
+            ]
+        )
+
+        def group_fn(node):
+            return set(node["groups"])
+
+        def key_fn(node):
+            if node["type"] == "in":
+                return "a"
+            if node["type"] != "op":
+                return "d"
+            if node["name"] == "measure":
+                return "d"
+            return "b" + chr(ord("a") + len(node["groups"]))
+
+        def filter_fn(node):
+            if node["type"] != "op":
+                return None
+            if node["name"] == "measure":
+                return False
+            return True
+
+        blocks = rustworkx.collect_multi_blocks(graph, 2, key_fn, group_fn, filter_fn)
+        self.assertEqual(blocks, [[4], [5, 7, 8, 9]])
+
+    def test_empty_graph(self):
+        graph = rustworkx.PyDiGraph()
+        block = rustworkx.collect_multi_blocks(graph, 1, lambda x: x, lambda x: x, lambda x: x)
+        self.assertEqual(block, [])
+
+    def test_larger_block(self):
+        graph = rustworkx.PyDiGraph()
+        q0 = graph.add_node({"type": "in", "name": "q0", "groups": []})
+        q1 = graph.add_node({"type": "in", "name": "q1", "groups": []})
+        q2 = graph.add_node({"type": "in", "name": "q2", "groups": []})
+        q3 = graph.add_node({"type": "in", "name": "q3", "groups": []})
+        q4 = graph.add_node({"type": "in", "name": "q4", "groups": []})
+        cx_1 = graph.add_node({"type": "op", "name": "cx", "groups": [0, 1]})
+        cx_2 = graph.add_node({"type": "op", "name": "cx", "groups": [1, 2]})
+        cx_3 = graph.add_node({"type": "op", "name": "cx", "groups": [2, 3]})
+        ccx = graph.add_node({"type": "op", "name": "ccx", "groups": [0, 1, 2]})
+        cx_4 = graph.add_node({"type": "op", "name": "cx", "groups": [3, 4]})
+        cx_5 = graph.add_node({"type": "op", "name": "cx", "groups": [3, 4]})
+        q0_out = graph.add_node({"type": "out", "name": "q0", "groups": []})
+        q1_out = graph.add_node({"type": "out", "name": "q1", "groups": []})
+        q2_out = graph.add_node({"type": "out", "name": "q2", "groups": []})
+        q3_out = graph.add_node({"type": "out", "name": "q3", "groups": []})
+        q4_out = graph.add_node({"type": "out", "name": "q4", "groups": []})
+
+        graph.add_edges_from(
+            [
+                (q0, cx_1, "q0"),
+                (q1, cx_1, "q1"),
+                (cx_1, cx_2, "q1"),
+                (q2, cx_2, "q2"),
+                (cx_2, cx_3, "q2"),
+                (q3, cx_3, "q3"),
+                (cx_1, ccx, "q0"),
+                (cx_2, ccx, "q1"),
+                (cx_3, ccx, "q2"),
+                (cx_3, cx_4, "q3"),
+                (q4, cx_4, "q4"),
+                (cx_4, cx_5, "q3"),
+                (cx_4, cx_5, "q4"),
+                (ccx, q0_out, "q0"),
+                (ccx, q1_out, "q1"),
+                (ccx, q2_out, "q2"),
+                (cx_5, q3_out, "q3"),
+                (cx_5, q4_out, "q4"),
+            ]
+        )
+
+        def group_fn(node):
+            return set(node["groups"])
+
+        def key_fn(node):
+            if node["type"] == "in":
+                return "a"
+            if node["type"] != "op":
+                return "d"
+            if node["name"] == "measure":
+                return "d"
+            return "b" + chr(ord("a") + len(node["groups"]))
+
+        def filter_fn(node):
+            if node["type"] != "op":
+                return None
+            if node["name"] == "measure":
+                return False
+            return True
+
+        blocks = rustworkx.collect_multi_blocks(graph, 4, key_fn, group_fn, filter_fn)
+        self.assertEqual([[cx_1, cx_2, cx_3], [ccx], [cx_4, cx_5]], blocks)