Skip to content

Commit

Permalink
Handle struct memlets in normalization, structure views as views
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Sep 25, 2023
1 parent d2c4370 commit 2943429
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def dealias_sdfg(sdfg: SDFG):
parent_arr.lifetime, parent_arr.alignment, parent_arr.debuginfo,
parent_arr.total_size, parent_arr.start_offset, parent_arr.optional,
parent_arr.pool)
elif isinstance(parent_arr, data.StructureView):
parent_arr = data.Structure(parent_arr.members, parent_arr.name, parent_arr.transient,
parent_arr.storage, parent_arr.location, parent_arr.lifetime,
parent_arr.debuginfo)
child_names = inv_replacements[parent_name]
for name in child_names:
child_arr = copy.deepcopy(parent_arr)
Expand Down Expand Up @@ -158,6 +162,8 @@ def normalize_memlet(sdfg: SDFG, state: SDFGState, original: gr.MultiConnectorEd
copy.deepcopy(original.data), original.key)
edge.data.try_initialize(sdfg, state, edge)

if '.' in edge.data.data and edge.data.data.startswith(data + '.'):
return edge.data
if edge.data.data == data:
return edge.data

Expand Down Expand Up @@ -379,7 +385,7 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[
# 1. Check for views
if isinstance(e.src, dace.nodes.AccessNode):
desc = e.src.desc(sdfg)
if isinstance(desc, dace.data.View):
if isinstance(desc, (dace.data.View, dace.data.StructureView)):
vedge = sdutil.get_view_edge(state, e.src)
if e is vedge:
viewed_node = sdutil.get_view_node(state, e.src)
Expand All @@ -389,7 +395,7 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[
continue
if isinstance(e.dst, dace.nodes.AccessNode):
desc = e.dst.desc(sdfg)
if isinstance(desc, dace.data.View):
if isinstance(desc, (dace.data.View, dace.data.StructureView)):
vedge = sdutil.get_view_edge(state, e.dst)
if e is vedge:
viewed_node = sdutil.get_view_node(state, e.dst)
Expand Down

0 comments on commit 2943429

Please sign in to comment.