From a25d0966f365152c3030d837ffc99a01dfabbb5e Mon Sep 17 00:00:00 2001 From: Yakup Budanaz Date: Tue, 29 Oct 2024 10:38:16 +0100 Subject: [PATCH] Use memlet trees for data access --- .../interstate/gpu_transform_sdfg.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 3fd533783d..b7aa3b708a 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -160,12 +160,17 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False return True - def _output_or_input_is_marked_host(self, state, entry_node): + def _get_marked_inputs_and_outputs(self, state, entry_node) -> list: if (self.host_data is None or self.host_data == []) and (self.host_maps is None or self.host_maps == []): - return False - marked_accesses = [e.data.data for e in state.in_edges(entry_node) + state.out_edges(state.exit_node(entry_node)) - if e.data.data is not None and e.data.data in self.host_data] - return len(marked_accesses) > 0 + return [] + marked_sources = [state.memlet_tree(e).root().edge.src for e in state.in_edges(entry_node)] + marked_destinations = [state.memlet_tree(e).root().edge.dst for e in state.in_edges(state.exit_node(entry_node))] + marked_accesses = [n.data for n in (marked_sources + marked_destinations) if isinstance(n, nodes.AccessNode) and n.data in self.host_data] + return marked_accesses + + def _output_or_input_is_marked_host(self, state, entry_node) -> bool: + marked_accesses = self._get_marked_inputs_and_outputs(state, entry_node) + return (len(marked_accesses) > 0) def apply(self, _, sdfg: sd.SDFG): @@ -188,8 +193,7 @@ def apply(self, _, sdfg: sd.SDFG): for state in sdfg.nodes(): for node in state.nodes(): if isinstance(node, nodes.EntryNode) and node.guid in self.host_maps: - accesses = {e.data.data for e in state.in_edges(node) + state.out_edges(state.exit_node(node)) - if e.data.data is not None and node.guid in self.host_maps} + accesses = self._get_marked_inputs_and_outputs(state, node) self.host_data.extend(accesses) for state in sdfg.nodes(): @@ -338,7 +342,7 @@ def apply(self, _, sdfg: sd.SDFG): for node in state.nodes(): if sdict[node] is None: if isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)): - if node.guid not in self.host_maps and not self._output_or_input_is_marked_host(state, node): + if node.guid: node.schedule = dtypes.ScheduleType.GPU_Default gpu_nodes.add((state, node)) elif isinstance(node, nodes.EntryNode):