Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
wenscarl committed Aug 8, 2024
1 parent 4d9d622 commit 8105930
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
2 changes: 1 addition & 1 deletion jax/_src/cudnn/fused_attention_stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def check_eq(a, b, c, msg):
raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}")

def check_is_flash_attention(
query, key, layout, cudnn_version, has_bias, is_training):
query, key, layout: int, cudnn_version, has_bias, is_training):
if layout == AttentionLayout.BNTH.value:
_, _, T, H = query.shape
_, _, S, _ = key.shape
Expand Down
23 changes: 16 additions & 7 deletions tests/fused_attention_stablehlo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,18 +429,27 @@ def _cvt_back(x):

def test_sdpa_utils(self):
test_cases = [
(1, 257, 64, 8905, False, True),
(1, 1024, 64, 8905, False, False),
(1024, 1024, 64, 8905, False, False),
(1024, 1024, 128, 8905, False, False),
(1, 257, 64, 8905, False, True, True),
(1, 1024, 64, 8905, False, False, True),
(1024, 1024, 64, 8905, False, False, True),
(1024, 1024, 128, 8905, False, False, True),
(1024, 1024, 127, 8905, False, False, False),
]

for k in test_cases:
sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training = k
sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training, \
expected_pass = k
query = jnp.empty((4, sql_q, 4, head_dim))
key = jnp.empty((4, sql_v, 4, head_dim))
check_is_flash_attention(
query, key, AttentionLayout.BNTH, cudnn_version, has_bias, is_training)
if expected_pass:
check_is_flash_attention(
query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
is_training)
else:
with self.assertRaises(NotImplementedError):
check_is_flash_attention(
query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
is_training)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 8105930

Please sign in to comment.