diff --git a/dace/transformation/passes/analysis/propagation.py b/dace/transformation/passes/analysis/propagation.py index d16629a1ed..defa12eedb 100644 --- a/dace/transformation/passes/analysis/propagation.py +++ b/dace/transformation/passes/analysis/propagation.py @@ -12,7 +12,6 @@ from dace.memlet import Memlet from dace.sdfg import nodes from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg.propagation import align_memlet, propagate_memlet from dace.sdfg.scope import ScopeTree from dace.sdfg.sdfg import SDFG, memlets_in_ast @@ -292,8 +291,7 @@ def _propagate_nsdfg(self, parent_sdfg: SDFG, parent_state: SDFGState, nsdfg_nod iedge.data.volume = 0 iedge.data.dynamic = True except (ValueError, NotImplementedError): - # In any case of memlets that cannot be unsqueezed (i.e., - # reshapes), use dynamic unbounded memlets. + # In any case of memlets that cannot be unsqueezed (i.e., reshapes), use dynamic unbounded memlets. iedge.data.volume = 0 iedge.data.dynamic = True for oedge in parent_state.out_edges(nsdfg_node): @@ -312,8 +310,7 @@ def _propagate_nsdfg(self, parent_sdfg: SDFG, parent_state: SDFGState, nsdfg_nod oedge.data.volume = 0 oedge.data.dynamic = True except (ValueError, NotImplementedError): - # In any case of memlets that cannot be unsqueezed (i.e., - # reshapes), use dynamic unbounded memlets. + # In any case of memlets that cannot be unsqueezed (i.e., reshapes), use dynamic unbounded memlets. oedge.data.volume = 0 oedge.data.dynamic = True diff --git a/tests/passes/analysis/propagation_test.py b/tests/passes/analysis/propagation_test.py index 53a320503d..e1ca0ec2a9 100644 --- a/tests/passes/analysis/propagation_test.py +++ b/tests/passes/analysis/propagation_test.py @@ -6,8 +6,13 @@ def test_nested_conditional_in_map(): """ - Thanks to the else branch, propagation should correctly identify that only A[0] is being read with volume 1 (in the - condition check of the branch), and A[i, :] (volume M) is written for each map iteration. + Thanks to the else branch, propagation should correctly identify that only A[0, 0] is being read with volume 1 (in + the condition check of the branch), and A[i, :] (volume N) is written for each map iteration. + NOTE: Due to view-based NSDFGs, currently the read is actually A[0, :] with volume N because of the way the nested + SDFG is constructed. The entire A[0] slice is passed to the nested SDFG and then the read to A[0, 0] happens + on an interstate edge inside the nested SDFG. This analysis correctly identifies the subset being read, but + the volume is technically wrong for now. This will be resolved when no-view-NSDFGs are introduced. + TODO: Revisit when no-view-NSDFGs are introduced. """ N = dace.symbol('N') M = dace.symbol('M') @@ -44,7 +49,7 @@ def test_nested_conditional_in_loop_in_map(): @dace.program def nested_conditional_in_loop_in_map(A: dace.int32[M, N]): for i in dace.map[0:M]: - for j in range(2, N, 1): + for j in range(0, N - 2, 1): if A[0][0]: A[i, j] = 1 else: @@ -60,9 +65,9 @@ def nested_conditional_in_loop_in_map(A: dace.int32[M, N]): assert 'A' in sdfg._certain_reads assert str(sdfg._certain_reads['A']) == 'A[0, 0]' assert 'A' in sdfg._possible_writes - assert str(sdfg._possible_writes['A']) == 'A[0:M, 0:N]' + assert str(sdfg._possible_writes['A']) == 'A[0:M, 0:N - 2]' assert 'A' in sdfg._certain_writes - assert str(sdfg._certain_writes['A']) == 'A[0:M, 0:N]' + assert str(sdfg._certain_writes['A']) == 'A[0:M, 0:N - 2]' def test_2D_map_added_indices(): """