From f95f8162a4e77d7a386ccd20c9e4ef71a3ad9787 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 4 Sep 2023 23:58:33 -0700 Subject: [PATCH] Fix dynamic memlet propagation condition (#1364) --- dace/sdfg/propagation.py | 4 ++-- tests/python_frontend/argument_test.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 0fec4812b7..0554775dcd 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1477,8 +1477,8 @@ def propagate_subset(memlets: List[Memlet], new_memlet.volume = simplify(sum(m.volume for m in memlets) * functools.reduce(lambda a, b: a * b, rng.size(), 1)) if any(m.dynamic for m in memlets): new_memlet.dynamic = True - elif symbolic.issymbolic(new_memlet.volume) and any(s not in defined_variables - for s in new_memlet.volume.free_symbols): + if symbolic.issymbolic(new_memlet.volume) and any(s not in defined_variables + for s in new_memlet.volume.free_symbols): new_memlet.dynamic = True new_memlet.volume = 0 diff --git a/tests/python_frontend/argument_test.py b/tests/python_frontend/argument_test.py index 1f43337eb8..cb47188029 100644 --- a/tests/python_frontend/argument_test.py +++ b/tests/python_frontend/argument_test.py @@ -2,6 +2,7 @@ import dace import pytest +import numpy as np N = dace.symbol('N') @@ -16,5 +17,29 @@ def test_extra_args(): imgcpy([[1, 2], [3, 4]], [[4, 3], [2, 1]], 0.0, 1.0) +def test_missing_arguments_regression(): + + def nester(a, b, T): + for i, j in dace.map[0:20, 0:20]: + start = 0 + end = min(T, 6) + + elem: dace.float64 = 0 + for ii in range(start, end): + if ii % 2 == 0: + elem += b[ii] + + a[j, i] = elem + + @dace.program + def tester(x: dace.float64[20, 20]): + gdx = np.ones((10, ), dace.float64) + for T in range(2): + nester(x, gdx, T) + + tester.to_sdfg().compile() + + if __name__ == '__main__': test_extra_args() + test_missing_arguments_regression()