diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 8b3e072232..f435050a82 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -267,6 +267,68 @@ def test_fusion_with_branch(): assert np.allclose(our_A, actual_A) +@dace.program +def assign_bottom_face_flipped(A: dace.float32[K, M, N]): + for t2, t1 in dace.map[0:N, 0:M]: + A[K - 1, t1, t2] = 1 + + +@dace.program +def assign_bounary_3d_with_flip(A: dace.float32[K, M, N], B: dace.float32[K, M, N]): + assign_top_face(A) + assign_bottom_face_flipped(B) + + +def test_does_not_permute_to_fuse(): + """ Negative test """ + A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + B = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = assign_bounary_3d_with_flip.to_sdfg(simplify=True, validate=True, use_cache=False) + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g.save(os.path.join('_dacegraphs', '3d-flip-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + actual_B = deepcopy(B) + g(A=actual_A, B=actual_B, K=3, M=4, N=5) + + assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 + g.save(os.path.join('_dacegraphs', '3d-flip-1.sdfg')) + g.validate() + our_A = deepcopy(A) + our_B = deepcopy(B) + g(A=our_A, B=our_B, K=3, M=4, N=5) + + +@dace.program +def assign_mixed_dims(A: dace.float32[K, M, N], B: dace.float32[K, M, N]): + assign_top_face(A) + assign_left_col(B[0, :, :]) + + +def test_does_not_extend_to_fuse(): + """ Negative test """ + A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + B = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = assign_mixed_dims.to_sdfg(simplify=True, validate=True, use_cache=False) + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g.save(os.path.join('_dacegraphs', '3d-mixed-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + actual_B = deepcopy(B) + g(A=actual_A, B=actual_B, K=3, M=4, N=5) + + assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 + g.save(os.path.join('_dacegraphs', '3d-mixed-1.sdfg')) + g.validate() + our_A = deepcopy(A) + our_B = deepcopy(B) + g(A=our_A, B=our_B, K=3, M=4, N=5) + + if __name__ == '__main__': test_within_state_fusion() test_interstate_fusion()