From dbf6cafefcf60b4ab97bf57754733991b2496c87 Mon Sep 17 00:00:00 2001 From: Yakup Budanaz Date: Thu, 28 Nov 2024 12:06:53 +0100 Subject: [PATCH] Decrease code duplication --- tests/transformations/gpu_transform_test.py | 43 ++++++--------------- 1 file changed, 12 insertions(+), 31 deletions(-) diff --git a/tests/transformations/gpu_transform_test.py b/tests/transformations/gpu_transform_test.py index 779f41f99b..62f95dcdb9 100644 --- a/tests/transformations/gpu_transform_test.py +++ b/tests/transformations/gpu_transform_test.py @@ -118,41 +118,22 @@ def write_subset_dynamic(A: dace.int32[20, 20], x: dace.int32[20], y: dace.int32 assert np.array_equal(ref, val) - -@pytest.mark.parametrize("transient", [False, True]) -def test_free_tasklet_and_array(transient): +@pytest.mark.parametrize(["transient", "scalar"], + [[False, False], [False, True], + [True, False], [True, True]]) +def test_free_tasklet(transient, scalar): sdfg = dace.SDFG("assign") state = sdfg.add_state("main") - arr_name, arr = sdfg.add_array("A", (4,), dace.float32, transient=transient) - an = state.add_access(arr_name) - - t = state.add_tasklet("assign", {}, {"_out"}, "_out = 2.0") - state.add_edge(t, "_out", an, None, dace.memlet.Memlet("A[0]")) - - sdfg.validate() - - sdfg.apply_gpu_transformations( - validate = True, - validate_all = True, - permissive = True, - sequential_innermaps=True, - register_transients=False, - simplify=False - ) + if scalar: + arr_name, arr = sdfg.add_scalar("A", dace.float32, transient=transient) + else: + arr_name, arr = sdfg.add_array("A", (4,), dace.float32, transient=transient) - sdfg.validate() - -@pytest.mark.parametrize("transient", [False, True]) -def test_free_tasklet_and_scalar(transient): - sdfg = dace.SDFG("assign") - - state = sdfg.add_state("main") - arr_name, arr = sdfg.add_scalar("A", dace.float32, transient=transient) an = state.add_access(arr_name) t = state.add_tasklet("assign", {}, {"_out"}, "_out = 2.0") - state.add_edge(t, "_out", an, None, dace.memlet.Memlet("A")) + state.add_edge(t, "_out", an, None, dace.memlet.Memlet("A" if scalar else "A[0]")) sdfg.validate() @@ -173,6 +154,6 @@ def test_free_tasklet_and_scalar(transient): test_write_subset() test_write_full() test_write_subset_dynamic() - for transient in [False, True]: - test_free_tasklet_and_array(transient) - test_free_tasklet_and_array(transient) + for scalar in [False, True]: + for transient in [False, True]: + test_free_tasklet(transient, scalar)