From fdce53a6b818d711b43dc74543cfd0be6d5cbb6e Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 13 Dec 2024 12:18:54 +0100 Subject: [PATCH] Ensure array elimination correctly merges view nodes --- .../passes/array_elimination.py | 53 +++++++++++++++---- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/dace/transformation/passes/array_elimination.py b/dace/transformation/passes/array_elimination.py index fd472336e0..e8cb44b46d 100644 --- a/dace/transformation/passes/array_elimination.py +++ b/dace/transformation/passes/array_elimination.py @@ -1,10 +1,13 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Set from dace import SDFG, SDFGState, data, properties +from dace.memlet import Memlet from dace.sdfg import nodes from dace.sdfg.analysis import cfg +from dace.sdfg.graph import MultiConnectorEdge +from dace.sdfg.validation import InvalidSDFGNodeError from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.dataflow import (RedundantArray, RedundantReadSlice, RedundantSecondArray, RedundantWriteSlice, SqueezeViewRemove, UnsqueezeViewRemove, RemoveSliceView) @@ -107,20 +110,52 @@ def merge_access_nodes(self, state: SDFGState, access_nodes: Dict[str, List[node removed_nodes: Set[nodes.AccessNode] = set() for nodeset in access_nodes.values(): if len(nodeset) > 1: - # Merge all other access nodes to the first one - first_node = nodeset[0] - if not condition(first_node): + # Merge all other access nodes to the first one that fits the condition, if one exists. + first_node = None + first_node_idx = 0 + for i, node in enumerate(nodeset[:-1]): + if condition(node): + first_node = node + first_node_idx = i + break + if first_node is None: continue - for node in nodeset[1:]: + + for node in nodeset[first_node_idx + 1:]: if not condition(node): continue - # Reconnect edges to first node - for edge in state.all_edges(node): + # Reconnect edges to first node. + # If we are handling views, we do not want to add more than one edge going into a 'views' connector, + # so we only merge nodes if the memlets match exactly (which they should). But in that case without + # copying the edge. + edges: List[MultiConnectorEdge[Memlet]] = state.all_edges(node) + other_edges: List[MultiConnectorEdge[Memlet]] = [] + for edge in edges: if edge.dst is node: - state.add_edge(edge.src, edge.src_conn, first_node, edge.dst_conn, edge.data) + if edge.dst_conn == 'views': + other_edges = list(state.in_edges_by_connector(first_node, 'views')) + if len(other_edges) != 1: + raise InvalidSDFGNodeError('Multiple edges connected to views connector', + state.sdfg, state.block_id, state.node_id(first_node)) + other_view_edge = other_edges[0] + if other_view_edge.data != edge.data: + # The memlets do not match, skip the node. + continue + else: + state.add_edge(edge.src, edge.src_conn, first_node, edge.dst_conn, edge.data) else: - state.add_edge(first_node, edge.src_conn, edge.dst, edge.dst_conn, edge.data) + if edge.src_conn == 'views': + other_edges = list(state.out_edges_by_connector(first_node, 'views')) + if len(other_edges) != 1: + raise InvalidSDFGNodeError('Multiple edges connected to views connector', + state.sdfg, state.block_id, state.node_id(first_node)) + other_view_edge = other_edges[0] + if other_view_edge.data != edge.data: + # The memlets do not match, skip the node. + continue + else: + state.add_edge(first_node, edge.src_conn, edge.dst, edge.dst_conn, edge.data) # Remove merged node and associated edges state.remove_node(node) removed_nodes.add(node)