Skip to content

Commit

Permalink
Ensure array elimination correctly merges view nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Dec 13, 2024
1 parent 200e606 commit fdce53a
Showing 1 changed file with 44 additions and 9 deletions.
53 changes: 44 additions & 9 deletions dace/transformation/passes/array_elimination.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set

from dace import SDFG, SDFGState, data, properties
from dace.memlet import Memlet
from dace.sdfg import nodes
from dace.sdfg.analysis import cfg
from dace.sdfg.graph import MultiConnectorEdge
from dace.sdfg.validation import InvalidSDFGNodeError
from dace.transformation import pass_pipeline as ppl, transformation
from dace.transformation.dataflow import (RedundantArray, RedundantReadSlice, RedundantSecondArray, RedundantWriteSlice,
SqueezeViewRemove, UnsqueezeViewRemove, RemoveSliceView)
Expand Down Expand Up @@ -107,20 +110,52 @@ def merge_access_nodes(self, state: SDFGState, access_nodes: Dict[str, List[node
removed_nodes: Set[nodes.AccessNode] = set()
for nodeset in access_nodes.values():
if len(nodeset) > 1:
# Merge all other access nodes to the first one
first_node = nodeset[0]
if not condition(first_node):
# Merge all other access nodes to the first one that fits the condition, if one exists.
first_node = None
first_node_idx = 0
for i, node in enumerate(nodeset[:-1]):
if condition(node):
first_node = node
first_node_idx = i
break
if first_node is None:
continue
for node in nodeset[1:]:

for node in nodeset[first_node_idx + 1:]:
if not condition(node):
continue

# Reconnect edges to first node
for edge in state.all_edges(node):
# Reconnect edges to first node.
# If we are handling views, we do not want to add more than one edge going into a 'views' connector,
# so we only merge nodes if the memlets match exactly (which they should). But in that case without
# copying the edge.
edges: List[MultiConnectorEdge[Memlet]] = state.all_edges(node)
other_edges: List[MultiConnectorEdge[Memlet]] = []
for edge in edges:
if edge.dst is node:
state.add_edge(edge.src, edge.src_conn, first_node, edge.dst_conn, edge.data)
if edge.dst_conn == 'views':
other_edges = list(state.in_edges_by_connector(first_node, 'views'))
if len(other_edges) != 1:
raise InvalidSDFGNodeError('Multiple edges connected to views connector',
state.sdfg, state.block_id, state.node_id(first_node))
other_view_edge = other_edges[0]
if other_view_edge.data != edge.data:
# The memlets do not match, skip the node.
continue
else:
state.add_edge(edge.src, edge.src_conn, first_node, edge.dst_conn, edge.data)
else:
state.add_edge(first_node, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
if edge.src_conn == 'views':
other_edges = list(state.out_edges_by_connector(first_node, 'views'))
if len(other_edges) != 1:
raise InvalidSDFGNodeError('Multiple edges connected to views connector',
state.sdfg, state.block_id, state.node_id(first_node))
other_view_edge = other_edges[0]
if other_view_edge.data != edge.data:
# The memlets do not match, skip the node.
continue
else:
state.add_edge(first_node, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
# Remove merged node and associated edges
state.remove_node(node)
removed_nodes.add(node)
Expand Down

0 comments on commit fdce53a

Please sign in to comment.