From bf56e4d065fc36881ad4d76f59ba7833f4724054 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 18 Dec 2023 09:13:18 -0800 Subject: [PATCH] Analysis passes for access range analysis (#1484) Adds two analysis passes to help with analyzing data access sets: access ranges and Reference sources. To enable constructing sets of memlets, this PR also reintroduces data descriptor names to memlet hashes. --- dace/memlet.py | 11 ++-- dace/transformation/passes/analysis.py | 80 +++++++++++++++++++++++++- tests/passes/access_ranges_test.py | 61 ++++++++++++++++++++ tests/sdfg/reference_test.py | 21 ++++++- 4 files changed, 163 insertions(+), 10 deletions(-) create mode 100644 tests/passes/access_ranges_test.py diff --git a/dace/memlet.py b/dace/memlet.py index d448ca1134..e7f0699eb8 100644 --- a/dace/memlet.py +++ b/dace/memlet.py @@ -169,8 +169,7 @@ def to_json(self): attrs['is_data_src'] = self._is_data_src # Fill in legacy (DEPRECATED) values for backwards compatibility - attrs['num_accesses'] = \ - str(self.volume) if not self.dynamic else -1 + attrs['num_accesses'] = str(self.volume) if not self.dynamic else -1 return {"type": "Memlet", "attributes": attrs} @@ -421,13 +420,11 @@ def from_array(dataname, datadesc, wcr=None): return Memlet.simple(dataname, rng, wcr_str=wcr) def __hash__(self): - return hash((self.volume, self.src_subset, self.dst_subset, str(self.wcr))) + return hash((self.data, self.volume, self.src_subset, self.dst_subset, str(self.wcr))) def __eq__(self, other): - return all([ - self.volume == other.volume, self.src_subset == other.src_subset, self.dst_subset == other.dst_subset, - self.wcr == other.wcr - ]) + return all((self.data == other.data, self.volume == other.volume, self.src_subset == other.src_subset, + self.dst_subset == other.dst_subset, self.wcr == other.wcr)) def replace(self, repl_dict): """ diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis.py index 86e1cde062..d6b235a876 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis.py @@ -2,7 +2,7 @@ from collections import defaultdict from dace.transformation import pass_pipeline as ppl -from dace import SDFG, SDFGState, properties, InterstateEdge +from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt from dace.sdfg.graph import Edge from dace.sdfg import nodes as nd from dace.sdfg.analysis import cfg @@ -505,3 +505,81 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i del result[desc][write] top_result[sdfg.sdfg_id] = result return top_result + + +@properties.make_properties +class AccessRanges(ppl.Pass): + """ + For each data descriptor, finds all memlets used to access it (read/write ranges). + """ + + CATEGORY: str = 'Analysis' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.Memlets + + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]: + """ + :return: A dictionary mapping each data descriptor name to a set of memlets. + """ + top_result: Dict[int, Dict[str, Set[Memlet]]] = dict() + + for sdfg in top_sdfg.all_sdfgs_recursive(): + result: Dict[str, Set[Memlet]] = defaultdict(set) + for state in sdfg.states(): + for anode in state.data_nodes(): + for e in state.all_edges(anode): + if e.dst is anode and e.dst_conn == 'set': # Skip reference sets + continue + if e.data.is_empty(): # Skip empty memlets + continue + # Find (hopefully propagated) root memlet + e = state.memlet_tree(e).root().edge + result[anode.data].add(e.data) + top_result[sdfg.sdfg_id] = result + return top_result + + +@properties.make_properties +class FindReferenceSources(ppl.Pass): + """ + For each Reference data descriptor, finds all memlets used to set it. If a Tasklet was used + to set the reference, the Tasklet is given as a source. + """ + + CATEGORY: str = 'Analysis' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.Memlets + + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, nd.CodeNode]]]]: + """ + :return: A dictionary mapping each data descriptor name to a set of memlets. + """ + top_result: Dict[int, Dict[str, Set[Union[Memlet, nd.CodeNode]]]] = dict() + + for sdfg in top_sdfg.all_sdfgs_recursive(): + result: Dict[str, Set[Memlet]] = defaultdict(set) + reference_descs = set(k for k, v in sdfg.arrays.items() if isinstance(v, dt.Reference)) + for state in sdfg.states(): + for anode in state.data_nodes(): + if anode.data not in reference_descs: + continue + for e in state.in_edges(anode): + if e.dst_conn != 'set': + continue + true_src = state.memlet_path(e)[0].src + if isinstance(true_src, nd.CodeNode): + # Code -> Reference + result[anode.data].add(true_src) + else: + # Array -> Reference + result[anode.data].add(e.data) + top_result[sdfg.sdfg_id] = result + return top_result diff --git a/tests/passes/access_ranges_test.py b/tests/passes/access_ranges_test.py new file mode 100644 index 0000000000..263cb2243d --- /dev/null +++ b/tests/passes/access_ranges_test.py @@ -0,0 +1,61 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests the AccessRanges analysis pass. """ +import dace +from dace.transformation.passes.analysis import AccessRanges +import numpy as np + +N = dace.symbol('N') + + +def test_simple(): + + @dace.program + def tester(A: dace.float64[N, N], B: dace.float64[20, 20]): + for i, j in dace.map[0:20, 0:N]: + A[i, j] = 1 + + sdfg = tester.to_sdfg(simplify=True) + ranges = AccessRanges().apply_pass(sdfg, {}) + assert len(ranges) == 1 # Only one SDFG + ranges = ranges[0] + assert len(ranges) == 1 # Only one array is accessed + + # Construct write memlet + memlet = dace.Memlet('A[0:20, 0:N]') + memlet._is_data_src = False + + assert ranges['A'] == {memlet} + + +def test_simple_ranges(): + + @dace.program + def tester(A: dace.float64[N, N], B: dace.float64[20, 20]): + A[:, :] = 0 + A[1:21, 1:21] = B + A[0, 0] += 1 + + sdfg = tester.to_sdfg(simplify=True) + ranges = AccessRanges().apply_pass(sdfg, {}) + assert len(ranges) == 1 # Only one SDFG + ranges = ranges[0] + assert len(ranges) == 2 # Two arrays are accessed + + assert len(ranges['B']) == 1 + assert next(iter(ranges['B'])).src_subset == dace.subsets.Range([(0, 19, 1), (0, 19, 1)]) + + # Construct read/write memlets + memlet1 = dace.Memlet('A[0:N, 0:N]') + memlet1._is_data_src = False + memlet2 = dace.Memlet('A[1:21, 1:21] -> 0:20, 0:20') + memlet2._is_data_src = False + memlet3 = dace.Memlet('A[0, 0]') + memlet4 = dace.Memlet('A[0, 0]') + memlet4._is_data_src = False + + assert ranges['A'] == {memlet1, memlet2, memlet3, memlet4} + + +if __name__ == '__main__': + test_simple() + test_simple_ranges() diff --git a/tests/sdfg/reference_test.py b/tests/sdfg/reference_test.py index 3f2cfb685c..f1e605e315 100644 --- a/tests/sdfg/reference_test.py +++ b/tests/sdfg/reference_test.py @@ -1,10 +1,11 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the use of Reference data descriptors. """ import dace +from dace.transformation.passes.analysis import FindReferenceSources import numpy as np -def test_reference_branch(): +def _create_branch_sdfg(): sdfg = dace.SDFG('refbranch') sdfg.add_array('A', [20], dace.float64) sdfg.add_array('B', [20], dace.float64) @@ -29,6 +30,11 @@ def test_reference_branch(): r = finish.add_read('ref') w = finish.add_write('out') finish.add_nedge(r, w, dace.Memlet('ref')) + return sdfg + + +def test_reference_branch(): + sdfg = _create_branch_sdfg() A = np.random.rand(20) B = np.random.rand(20) @@ -41,5 +47,16 @@ def test_reference_branch(): assert np.allclose(out, A) +def test_reference_sources_pass(): + sdfg = _create_branch_sdfg() + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 1 and 'ref' in sources # There is one reference + sources = sources['ref'] + assert sources == {dace.Memlet('A[0:20]', volume=1), dace.Memlet('B[0:20]', volume=1)} + + if __name__ == '__main__': test_reference_branch() + test_reference_sources_pass()