diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index b33568e864..b5a80597c0 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -68,6 +68,10 @@ def dealias_sdfg(sdfg: SDFG): parent_arr.lifetime, parent_arr.alignment, parent_arr.debuginfo, parent_arr.total_size, parent_arr.start_offset, parent_arr.optional, parent_arr.pool) + elif isinstance(parent_arr, data.StructureView): + parent_arr = data.Structure(parent_arr.members, parent_arr.name, parent_arr.transient, + parent_arr.storage, parent_arr.location, parent_arr.lifetime, + parent_arr.debuginfo) child_names = inv_replacements[parent_name] for name in child_names: child_arr = copy.deepcopy(parent_arr) @@ -158,6 +162,8 @@ def normalize_memlet(sdfg: SDFG, state: SDFGState, original: gr.MultiConnectorEd copy.deepcopy(original.data), original.key) edge.data.try_initialize(sdfg, state, edge) + if '.' in edge.data.data and edge.data.data.startswith(data + '.'): + return edge.data if edge.data.data == data: return edge.data @@ -379,7 +385,7 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ # 1. Check for views if isinstance(e.src, dace.nodes.AccessNode): desc = e.src.desc(sdfg) - if isinstance(desc, dace.data.View): + if isinstance(desc, (dace.data.View, dace.data.StructureView)): vedge = sdutil.get_view_edge(state, e.src) if e is vedge: viewed_node = sdutil.get_view_node(state, e.src) @@ -389,7 +395,7 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ continue if isinstance(e.dst, dace.nodes.AccessNode): desc = e.dst.desc(sdfg) - if isinstance(desc, dace.data.View): + if isinstance(desc, (dace.data.View, dace.data.StructureView)): vedge = sdutil.get_view_edge(state, e.dst) if e is vedge: viewed_node = sdutil.get_view_node(state, e.dst)