Skip to content

Commit

Permalink
Use memlet trees for data access
Browse files Browse the repository at this point in the history
  • Loading branch information
ThrudPrimrose committed Oct 29, 2024
1 parent 60cf46d commit a25d096
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions dace/transformation/interstate/gpu_transform_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,17 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
return False
return True

def _output_or_input_is_marked_host(self, state, entry_node):
def _get_marked_inputs_and_outputs(self, state, entry_node) -> list:
if (self.host_data is None or self.host_data == []) and (self.host_maps is None or self.host_maps == []):
return False
marked_accesses = [e.data.data for e in state.in_edges(entry_node) + state.out_edges(state.exit_node(entry_node))
if e.data.data is not None and e.data.data in self.host_data]
return len(marked_accesses) > 0
return []
marked_sources = [state.memlet_tree(e).root().edge.src for e in state.in_edges(entry_node)]
marked_destinations = [state.memlet_tree(e).root().edge.dst for e in state.in_edges(state.exit_node(entry_node))]
marked_accesses = [n.data for n in (marked_sources + marked_destinations) if isinstance(n, nodes.AccessNode) and n.data in self.host_data]
return marked_accesses

def _output_or_input_is_marked_host(self, state, entry_node) -> bool:
marked_accesses = self._get_marked_inputs_and_outputs(state, entry_node)
return (len(marked_accesses) > 0)


def apply(self, _, sdfg: sd.SDFG):
Expand All @@ -188,8 +193,7 @@ def apply(self, _, sdfg: sd.SDFG):
for state in sdfg.nodes():
for node in state.nodes():
if isinstance(node, nodes.EntryNode) and node.guid in self.host_maps:
accesses = {e.data.data for e in state.in_edges(node) + state.out_edges(state.exit_node(node))
if e.data.data is not None and node.guid in self.host_maps}
accesses = self._get_marked_inputs_and_outputs(state, node)
self.host_data.extend(accesses)

for state in sdfg.nodes():
Expand Down Expand Up @@ -338,7 +342,7 @@ def apply(self, _, sdfg: sd.SDFG):
for node in state.nodes():
if sdict[node] is None:
if isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)):
if node.guid not in self.host_maps and not self._output_or_input_is_marked_host(state, node):
if node.guid:
node.schedule = dtypes.ScheduleType.GPU_Default
gpu_nodes.add((state, node))
elif isinstance(node, nodes.EntryNode):
Expand Down

0 comments on commit a25d096

Please sign in to comment.