Skip to content

Commit

Permalink
Add two designated "negative tests".
Browse files Browse the repository at this point in the history
  • Loading branch information
pratyai committed Oct 28, 2024
1 parent db94444 commit 0cd919f
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions tests/transformations/const_assignment_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,68 @@ def test_fusion_with_branch():
assert np.allclose(our_A, actual_A)


@dace.program
def assign_bottom_face_flipped(A: dace.float32[K, M, N]):
for t2, t1 in dace.map[0:N, 0:M]:
A[K - 1, t1, t2] = 1


@dace.program
def assign_bounary_3d_with_flip(A: dace.float32[K, M, N], B: dace.float32[K, M, N]):
assign_top_face(A)
assign_bottom_face_flipped(B)


def test_does_not_permute_to_fuse():
""" Negative test """
A = np.random.uniform(size=(3, 4, 5)).astype(np.float32)
B = np.random.uniform(size=(3, 4, 5)).astype(np.float32)

# Construct SDFG with the maps on separate states.
g = assign_bounary_3d_with_flip.to_sdfg(simplify=True, validate=True, use_cache=False)
g.apply_transformations_repeated(StateFusionExtended, validate_all=True)
g.save(os.path.join('_dacegraphs', '3d-flip-0.sdfg'))
g.validate()
actual_A = deepcopy(A)
actual_B = deepcopy(B)
g(A=actual_A, B=actual_B, K=3, M=4, N=5)

assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0
g.save(os.path.join('_dacegraphs', '3d-flip-1.sdfg'))
g.validate()
our_A = deepcopy(A)
our_B = deepcopy(B)
g(A=our_A, B=our_B, K=3, M=4, N=5)


@dace.program
def assign_mixed_dims(A: dace.float32[K, M, N], B: dace.float32[K, M, N]):
assign_top_face(A)
assign_left_col(B[0, :, :])


def test_does_not_extend_to_fuse():
""" Negative test """
A = np.random.uniform(size=(3, 4, 5)).astype(np.float32)
B = np.random.uniform(size=(3, 4, 5)).astype(np.float32)

# Construct SDFG with the maps on separate states.
g = assign_mixed_dims.to_sdfg(simplify=True, validate=True, use_cache=False)
g.apply_transformations_repeated(StateFusionExtended, validate_all=True)
g.save(os.path.join('_dacegraphs', '3d-mixed-0.sdfg'))
g.validate()
actual_A = deepcopy(A)
actual_B = deepcopy(B)
g(A=actual_A, B=actual_B, K=3, M=4, N=5)

assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0
g.save(os.path.join('_dacegraphs', '3d-mixed-1.sdfg'))
g.validate()
our_A = deepcopy(A)
our_B = deepcopy(B)
g(A=our_A, B=our_B, K=3, M=4, N=5)


if __name__ == '__main__':
test_within_state_fusion()
test_interstate_fusion()
Expand Down

0 comments on commit 0cd919f

Please sign in to comment.