Skip to content

Commit

Permalink
Updated and fixed the MapDimShuffle tranformation. (#1531)
Browse files Browse the repository at this point in the history
The reordering of the dimension was quite strange, it is now much more
cleaner. Furthermore, the `tile_sizes` parameter of the map range was
ignored, i.e. not changed. Furthermore, the parameters are now correclty
copied. Before it was just assigned, thus if later the parameter list of
the map was again changed this effect would propagate to _all_ maps that
where treated.
  • Loading branch information
philip-paul-mueller authored Feb 27, 2024
1 parent ba1587e commit b1a7f8a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
23 changes: 12 additions & 11 deletions dace/transformation/dataflow/map_dim_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,19 @@ def expressions(cls):
return [sdutil.node_path_graph(cls.map_entry)]

def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
map_entry: nodes.MapEntry = self.map_entry
if self.parameters is None:
return False
if len(self.parameters) != len(map_entry.map.params):
return False
if set(self.parameters) != set(map_entry.map.params):
return False
return True

def apply(self, graph: SDFGState, sdfg: SDFG):
map_entry = self.map_entry
if self.parameters is None:
return

if set(self.parameters) != set(map_entry.map.params):
return
map_entry: nodes.MapEntry = self.map_entry
new_map_order: list[int] = [map_entry.map.params.index(param) for param in self.parameters]

map_entry.range.ranges = [
r for list_param in self.parameters for map_param, r in zip(map_entry.map.params, map_entry.range.ranges)
if list_param == map_param
]
map_entry.map.params = self.parameters
map_entry.range.ranges = [map_entry.range.ranges[new_pos] for new_pos in new_map_order]
map_entry.range.tile_sizes = [map_entry.range.tile_sizes[new_pos] for new_pos in new_map_order]
map_entry.map.params = [map_entry.map.params[new_pos] for new_pos in new_map_order]
3 changes: 3 additions & 0 deletions tests/transformations/map_dim_shuffle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def test_map_dim_shuffle():
sdfg(A=A, B=B)
assert np.allclose(B, expected)

assert sdfg.apply_transformations_repeated(MapDimShuffle, options={"parameters": ["k", "i"]}) == 0
assert sdfg.apply_transformations_repeated(MapDimShuffle, options={"parameters": ["k", "i", "l"]}) == 0


if __name__ == '__main__':
test_map_dim_shuffle()

0 comments on commit b1a7f8a

Please sign in to comment.