From 233904582e7c28a869771a8037fbc3638fe5b44e Mon Sep 17 00:00:00 2001 From: Cory Forsstrom Date: Fri, 13 Oct 2023 13:29:03 -0700 Subject: [PATCH] Fix subgraph, connect edges properly --- dag/src/subgraph.rs | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/dag/src/subgraph.rs b/dag/src/subgraph.rs index c148b522..ef65840d 100644 --- a/dag/src/subgraph.rs +++ b/dag/src/subgraph.rs @@ -6,19 +6,28 @@ use petgraph::{prelude::Graph, stable_graph::IndexType, visit::Dfs, EdgeType}; /// Given an input [`Graph`] and the start nodes, construct a subgraph /// Used largely in transposed form for reverse dependency calculation -pub fn subgraph(graph: &Graph, starting_nodes: &[N]) -> Graph +pub fn subgraph( + graph: &Graph, + starting_nodes: &[N], +) -> Graph where N: PartialEq + Clone, - E: Default, + E: Clone, Ix: IndexType, Ty: EdgeType, { - let mut res = Graph::default(); + let add_node = |graph: &mut Graph, node| { + if let Some(index) = graph.node_indices().find(|i| graph[*i] == node) { + index + } else { + graph.add_node(node) + } + }; + let mut res = Graph::default(); let mut dfs = Dfs::empty(&graph); - for starting_node in starting_nodes { - let node_index = res.add_node(starting_node.clone()); + for starting_node in starting_nodes { let Some(starting_node_index) = graph.node_indices().find(|n| graph[*n] == *starting_node) else { continue; @@ -28,9 +37,11 @@ where while let Some(node) = dfs.next(&graph) { for neighbor in graph.neighbors_directed(node, petgraph::Direction::Outgoing) { - let neighbor_node = graph[neighbor].clone(); - let neighbor_index = res.add_node(neighbor_node); - res.add_edge(node_index, neighbor_index, E::default()); + if let Some(edge) = graph.find_edge(node, neighbor) { + let node_index = add_node(&mut res, graph[node].clone()); + let neighbor_index = add_node(&mut res, graph[neighbor].clone()); + res.add_edge(node_index, neighbor_index, graph[edge].clone()); + } } } }