diff --git a/dace/transformation/passes/analysis/propagation.py b/dace/transformation/passes/analysis/propagation.py index 662e6b1b8f..d3d5545c34 100644 --- a/dace/transformation/passes/analysis/propagation.py +++ b/dace/transformation/passes/analysis/propagation.py @@ -16,7 +16,7 @@ from dace.sdfg.scope import ScopeTree from dace.sdfg.sdfg import SDFG, memlets_in_ast from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, SDFGState -from dace.subsets import Range, Subset, SubsetUnion +from dace.subsets import Range, SubsetUnion from dace.transformation import pass_pipeline as ppl from dace.transformation import transformation from dace.transformation.helpers import unsqueeze_memlet @@ -331,7 +331,7 @@ def _propagate_state(self, state: SDFGState) -> None: for iedge in state.in_edges(anode): if not iedge.data.is_empty(): root_edge = state.memlet_tree(iedge).root().edge - writes[anode.data].append([root_edge.data, anode]) + writes[anode.data].append((root_edge.data, anode)) # Go over (overapproximated) reads and check if they are covered by writes. not_covered_reads: Dict[str, Set[Memlet]] = defaultdict(set) @@ -363,27 +363,28 @@ def _propagate_state(self, state: SDFGState) -> None: state._certain_writes = {} state._possible_writes = {} for data in writes: - subset = None - volume = None - is_dynamic = False - for memlet, _ in writes[data]: - is_dynamic |= memlet.dynamic - if subset is None: - subset = SubsetUnion(memlet.dst_subset or memlet.subset) - else: - subset.union(memlet.dst_subset or memlet.subset) - if memlet.volume == 0: - volume = 0 - else: - if volume is None: - volume = memlet.volume - elif volume != 0: - volume += memlet.volume - new_memlet = Memlet(data=data, subset=subset) - new_memlet.dynamic = is_dynamic - new_memlet.volume = volume - state._certain_writes[data] = new_memlet - state._possible_writes[data] = new_memlet + if len(writes[data]) > 0: + subset = None + volume = None + is_dynamic = False + for memlet, _ in writes[data]: + is_dynamic |= memlet.dynamic + if subset is None: + subset = SubsetUnion(memlet.dst_subset or memlet.subset) + else: + subset.union(memlet.dst_subset or memlet.subset) + if memlet.volume == 0: + volume = 0 + else: + if volume is None: + volume = memlet.volume + elif volume != 0: + volume += memlet.volume + new_memlet = Memlet(data=data, subset=subset) + new_memlet.dynamic = is_dynamic + new_memlet.volume = volume if volume is not None else 0 + state._certain_writes[data] = new_memlet + state._possible_writes[data] = new_memlet state._certain_reads = {} state._possible_reads = {} diff --git a/tests/passes/analysis/propagation_test.py b/tests/passes/analysis/propagation_test.py index 83be47dd07..7928dc6a00 100644 --- a/tests/passes/analysis/propagation_test.py +++ b/tests/passes/analysis/propagation_test.py @@ -5,15 +5,6 @@ def test_nested_conditional_in_map(): - """ - Thanks to the else branch, propagation should correctly identify that only A[0, 0] is being read with volume 1 (in - the condition check of the branch), and A[i, :] (volume N) is written for each map iteration. - NOTE: Due to view-based NSDFGs, currently the read is actually A[0, :] with volume N because of the way the nested - SDFG is constructed. The entire A[0] slice is passed to the nested SDFG and then the read to A[0, 0] happens - on an interstate edge inside the nested SDFG. This analysis correctly identifies the subset being read, but - the volume is technically wrong for now. This will be resolved when no-view-NSDFGs are introduced. - TODO: Revisit when no-view-NSDFGs are introduced. - """ N = dace.symbol('N') M = dace.symbol('M') @@ -47,10 +38,6 @@ def nested_conditional_in_map(A: dace.int32[M, N]): assert sdfg._certain_writes['A'].volume == M * N def test_nested_conditional_in_loop_in_map(): - """ - Write in nested SDFG in two-dimensional map nest. - Nested map does not iterate over shape of second array dimension. - --> should approximate write-set of map nest precisely.""" N = dace.symbol('N') M = dace.symbol('M') @@ -85,35 +72,110 @@ def nested_conditional_in_loop_in_map(A: dace.int32[M, N]): assert sdfg._certain_writes['A'].dynamic == False assert sdfg._certain_writes['A'].volume == M * (N - 2) -def test_2D_map_added_indices(): - """ - 2-dimensional array that writes to two-dimensional array with - subscript expression that adds two indices - --> Approximated write-set of Map is empty - """ +def test_runtime_conditional(): + @dace.program + def rconditional(in1: dace.float64[10], out: dace.float64[10], mask: dace.int32[10]): + for i in dace.map[1:10]: + if mask[i] > 0: + out[i] = in1[i - 1] + else: + out[i] = in1[i] + + sdfg = rconditional.to_sdfg(simplify=True) + + MemletPropagation().apply_pass(sdfg, {}) + + assert 'mask' in sdfg._possible_reads + assert str(sdfg._possible_reads['mask'].subset) == '1:10' + assert sdfg._possible_reads['mask'].dynamic == False + assert sdfg._possible_reads['mask'].volume == 9 + assert 'in1' in sdfg._possible_reads + assert str(sdfg._possible_reads['in1'].subset) == '0:10' + assert sdfg._possible_reads['in1'].dynamic == False + assert sdfg._possible_reads['in1'].volume == 18 + + assert 'mask' in sdfg._certain_reads + assert str(sdfg._certain_reads['mask'].subset) == '1:10' + assert sdfg._certain_reads['mask'].dynamic == False + assert sdfg._certain_reads['mask'].volume == 9 + assert 'in1' in sdfg._certain_reads + assert str(sdfg._certain_reads['in1'].subset) == '0:10' + assert sdfg._certain_reads['in1'].dynamic == False + assert sdfg._certain_reads['in1'].volume == 18 + + assert 'out' in sdfg._possible_writes + assert str(sdfg._possible_writes['out'].subset) == '1:10' + assert sdfg._possible_writes['out'].dynamic == False + assert sdfg._possible_writes['out'].volume == 9 + + assert 'out' in sdfg._certain_writes + assert str(sdfg._certain_writes['out'].subset) == '1:10' + assert sdfg._certain_writes['out'].dynamic == False + assert sdfg._certain_writes['out'].volume == 9 + +def test_nsdfg_memlet_propagation_with_one_sparse_dimension(): + N = dace.symbol('N') + M = dace.symbol('M') + @dace.program + def sparse(A: dace.float32[M, N], ind: dace.int32[M, N]): + for i, j in dace.map[0:M, 0:N]: + A[i, ind[i, j]] += 1 + + sdfg = sparse.to_sdfg(simplify=False) + + MemletPropagation().apply_pass(sdfg, {}) + + assert 'ind' in sdfg._possible_reads + assert str(sdfg._possible_reads['ind'].subset) == '0:M, 0:N' + assert sdfg._possible_reads['ind'].dynamic == False + assert sdfg._possible_reads['ind'].volume == N * M + + assert 'ind' in sdfg._certain_reads + assert str(sdfg._certain_reads['ind'].subset) == '0:M, 0:N' + assert sdfg._certain_reads['ind'].dynamic == False + assert sdfg._certain_reads['ind'].volume == N * M + + assert 'A' in sdfg._possible_writes + assert str(sdfg._possible_writes['A'].subset) == '0:M, 0:N' + assert sdfg._possible_writes['A'].dynamic == False + assert sdfg._possible_writes['A'].volume == N * M + + assert 'A' in sdfg._certain_writes + assert str(sdfg._certain_writes['A'].subset) == '0:M, 0:N' + assert sdfg._certain_writes['A'].dynamic == False + assert sdfg._certain_writes['A'].volume == N * M +def test_nested_loop_in_map(): N = dace.symbol('N') M = dace.symbol('M') - sdfg = dace.SDFG("twoD_map") - sdfg.add_array("B", (M, N), dace.float64) - map_state = sdfg.add_state("map") - a1 = map_state.add_access('B') - map_state.add_mapped_tasklet("overwrite_1", - map_ranges={ - '_i': '0:N:1', - '_j': '0:M:1' - }, - inputs={}, - code="b = 5", - outputs={"b": dace.Memlet("B[_j,_i + _j]")}, - output_nodes={"B": a1}, - external_edges=True) + @dace.program + def nested_loop_in_map(A: dace.float64[N, M]): + for i in dace.map[0:N]: + for j in range(M): + A[i, j] = 0 + + sdfg = nested_loop_in_map.to_sdfg(simplify=True) - print(sdfg) + MemletPropagation().apply_pass(sdfg, {}) + + assert sdfg._possible_reads == {} + assert sdfg._certain_reads == {} + + assert 'A' in sdfg._possible_writes + assert str(sdfg._possible_writes['A'].subset) == '0:N, 0:M' + assert sdfg._possible_writes['A'].dynamic == False + assert sdfg._possible_writes['A'].volume == N * M + + assert 'A' in sdfg._certain_writes + assert str(sdfg._certain_writes['A'].subset) == '0:N, 0:M' + assert sdfg._certain_writes['A'].dynamic == False + assert sdfg._certain_writes['A'].volume == N * M if __name__ == '__main__': test_nested_conditional_in_map() test_nested_conditional_in_loop_in_map() - #test_2D_map_added_indices() + test_runtime_conditional() + test_nsdfg_memlet_propagation_with_one_sparse_dimension() + test_nested_loop_in_map()