Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add collect_multi_blocks function #461

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
97 changes: 97 additions & 0 deletions src/dag_algo/lexicographical_sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use std::cmp::Ordering;
use std::collections::BinaryHeap;

use hashbrown::HashMap;

use pyo3::prelude::*;
use pyo3::Python;

use petgraph::graph::NodeIndex;
use petgraph::visit::NodeCount;

use crate::digraph;

/// Compute a lexicographical topological ordering of the nodes in `dag`.
///
/// Nodes become eligible for output once all of their incoming edges have
/// been consumed. Among the eligible nodes, the one whose `key` string is
/// smallest is emitted first; equal keys are resolved by the smaller node
/// index. Nodes whose in-degree never drops to zero are not part of the
/// returned list.
///
/// * `py` - The Python GIL token used to call `key` and extract its result.
/// * `dag` - The graph to sort.
/// * `key` - A Python callable that receives a node's data payload and must
///   return a value extractable as a string.
///
/// Returns the node indices in sorted order. Any Python exception raised by
/// `key`, or by converting its return value into a `String`, is propagated
/// as the `Err` variant.
pub fn lexicographical_topological_sort(
    py: Python,
    dag: &digraph::PyDiGraph,
    key: PyObject,
) -> PyResult<Vec<NodeIndex>> {
    /// Heap entry pairing a node with the sort key computed for it.
    #[derive(Clone, Eq, PartialEq)]
    struct State {
        key: String,
        node: NodeIndex,
    }

    impl Ord for State {
        fn cmp(&self, other: &State) -> Ordering {
            // `BinaryHeap` is a max-heap, so both comparisons are reversed to
            // make the smallest key (and then the smallest node index) pop
            // first. Comparing the node index on ties also keeps `Ord`
            // consistent with the derived `PartialEq`.
            match other.key.cmp(&self.key) {
                Ordering::Equal => other.node.index().cmp(&self.node.index()),
                unequal => unequal,
            }
        }
    }

    // `PartialOrd` simply defers to the total order defined above.
    impl PartialOrd for State {
        fn partial_cmp(&self, other: &State) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    // Build the heap entry for a node by calling the user supplied key
    // function on the node's payload and extracting the result as a string.
    let make_state = |node: NodeIndex| -> PyResult<State> {
        let raw_key = key.call1(py, (&dag.graph[node],))?.to_object(py);
        let sort_key: String = raw_key.extract(py)?;
        Ok(State {
            key: sort_key,
            node,
        })
    };

    // Remaining (unconsumed) in-degree of every node in the graph.
    let node_count = dag.node_count();
    let mut in_degree_map: HashMap<NodeIndex, usize> = HashMap::with_capacity(node_count);
    for node in dag.graph.node_indices() {
        in_degree_map.insert(node, dag.in_degree(node.index()));
    }

    // Seed the heap with every node that has no incoming edges.
    let mut ready: BinaryHeap<State> = BinaryHeap::with_capacity(node_count);
    for (&node, &degree) in in_degree_map.iter() {
        if degree == 0 {
            ready.push(make_state(node)?);
        }
    }

    let mut sorted: Vec<NodeIndex> = Vec::with_capacity(node_count);
    while let Some(State { node, .. }) = ready.pop() {
        for child in dag
            .graph
            .neighbors_directed(node, petgraph::Direction::Outgoing)
        {
            let remaining = in_degree_map.get_mut(&child).unwrap();
            *remaining -= 1;
            if *remaining != 0 {
                continue;
            }
            // Every incoming edge of `child` has been consumed, so it is now
            // eligible for output and no longer needs a counter.
            ready.push(make_state(child)?);
            in_degree_map.remove(&child);
        }
        sorted.push(node);
    }
    Ok(sorted)
}
110 changes: 41 additions & 69 deletions src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
// License for the specific language governing permissions and limitations
// under the License.

mod lexicographical_sort;
mod longest_path;
mod multiblock;

use hashbrown::{HashMap, HashSet};
use std::cmp::Ordering;
use std::collections::BinaryHeap;

use super::iterators::NodeIndices;
use crate::{digraph, DAGHasCycle, InvalidNode};
Expand Down Expand Up @@ -341,7 +341,7 @@ pub fn layers(
/// order used.
///
/// :param PyDiGraph dag: The DAG to get the topological sorted nodes from
/// :param callable key: key is a python function or other callable that
/// :param Callable key: key is a python function or other callable that
/// gets passed a single argument the node data from the graph and is
/// expected to return a string which will be used for resolving ties
/// in the sorting order.
Expand All @@ -355,75 +355,47 @@ pub fn lexicographical_topological_sort(
dag: &digraph::PyDiGraph,
key: PyObject,
) -> PyResult<PyObject> {
let key_callable = |a: &PyObject| -> PyResult<PyObject> {
let res = key.call1(py, (a,))?;
Ok(res.to_object(py))
};
// HashMap of node_index indegree
let node_count = dag.node_count();
let mut in_degree_map: HashMap<NodeIndex, usize> = HashMap::with_capacity(node_count);
for node in dag.graph.node_indices() {
in_degree_map.insert(node, dag.in_degree(node.index()));
}

#[derive(Clone, Eq, PartialEq)]
struct State {
key: String,
node: NodeIndex,
}

impl Ord for State {
fn cmp(&self, other: &State) -> Ordering {
// Notice that the we flip the ordering on costs.
// In case of a tie we compare positions - this step is necessary
// to make implementations of `PartialEq` and `Ord` consistent.
other
.key
.cmp(&self.key)
.then_with(|| other.node.index().cmp(&self.node.index()))
}
}

// `PartialOrd` needs to be implemented as well.
impl PartialOrd for State {
fn partial_cmp(&self, other: &State) -> Option<Ordering> {
Some(self.cmp(other))
}
}
let mut zero_indegree = BinaryHeap::with_capacity(node_count);
for (node, degree) in in_degree_map.iter() {
if *degree == 0 {
let map_key_raw = key_callable(&dag.graph[*node])?;
let map_key: String = map_key_raw.extract(py)?;
zero_indegree.push(State {
key: map_key,
node: *node,
});
}
}
let mut out_list: Vec<&PyObject> = Vec::with_capacity(node_count);
let dir = petgraph::Direction::Outgoing;
while let Some(State { node, .. }) = zero_indegree.pop() {
let neighbors = dag.graph.neighbors_directed(node, dir);
for child in neighbors {
let child_degree = in_degree_map.get_mut(&child).unwrap();
*child_degree -= 1;
if *child_degree == 0 {
let map_key_raw = key_callable(&dag.graph[child])?;
let map_key: String = map_key_raw.extract(py)?;
zero_indegree.push(State {
key: map_key,
node: child,
});
in_degree_map.remove(&child);
}
}
out_list.push(&dag.graph[node])
}
let sort_list = lexicographical_sort::lexicographical_topological_sort(py, dag, key)?;
let out_list: Vec<&PyObject> = sort_list.iter().map(|node| &dag.graph[*node]).collect();
Ok(PyList::new(py, out_list).into())
}

/// Return the topological sort of node indices from the provided graph
/// Collect multi-blocks from the graph
///
/// Multi-blocks are uninterrupted sequences of nodes that operate on the same
/// groups (see ``group_fn``).
///
/// :param PyDiGraph dag: The graph to find the multiblocks in
/// :param int block_size: The maximum block size to find blocks for
/// :param Callable key: key is a python function or other callable that
///     gets passed a single argument the node data from the graph and is
///     expected to return a string which will be used for sorting. This is
///     the same as the ``key`` parameter from
///     :func:`~retworkx.lexicographical_topological_sort`
/// :param Callable group_fn: A callback function that will receive the node's
///     data payload as the algorithm runs and it should return a ``set`` of
///     groups that the node operates on
/// :param Callable filter_fn: A filter function that will receive a node's
///     data payload and is expected to either return ``None``, ``True``, or
///     ``False``. If ``None`` the node will be skipped, if ``True`` the
///     node will be processed by the algorithm, and if ``False`` the node
///     will be considered unprocessable by the algorithm but will update
///     its internal state around this.
/// :returns: a list of lists of node indices for each multiblock found in the
///     graph
/// :rtype: list
// NOTE(review): the second paragraph of the doc comment was truncated after
// "the same" in the original; "groups" is inferred from the ``group_fn``
// description - confirm the intended wording.
#[pyfunction]
pub fn collect_multi_blocks(
py: Python,
dag: &digraph::PyDiGraph,
block_size: usize,
key: PyObject,
group_fn: PyObject,
filter_fn: PyObject,
) -> PyResult<Vec<Vec<usize>>> {
// Python-facing entry point: all arguments are forwarded unchanged to the
// implementation in the `multiblock` submodule.
multiblock::collect_multi_blocks(py, dag, block_size, key, group_fn, filter_fn)
}

/// Return the topological sort of node indexes from the provided graph
///
/// :param PyDiGraph graph: The DAG to get the topological sort on
///
Expand Down
Loading