Merge pull request #22882 from wenscarl:attn_layout_fix
PiperOrigin-RevId: 668636119
jax authors committed Aug 28, 2024
2 parents 28a6558 + 8105930 commit 2785a08
Showing 2 changed files with 18 additions and 9 deletions.
4 changes: 2 additions & 2 deletions jax/_src/cudnn/fused_attention_stablehlo.py
@@ -284,8 +284,8 @@ def check_eq(a, b, c, msg):
     raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}")
 
 def check_is_flash_attention(
-    query, key, layout, cudnn_version, has_bias, is_training):
-  if layout == AttentionLayout.BNTH:
+    query, key, layout: int, cudnn_version, has_bias, is_training):
+  if layout == AttentionLayout.BNTH.value:
     _, _, T, H = query.shape
     _, _, S, _ = key.shape
   else:
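The root cause: callers pass layout to check_is_flash_attention as a plain integer, and a plain enum.Enum member never compares equal to its underlying int, so layout == AttentionLayout.BNTH was always False and the else branch ran regardless of layout. Hence the layout: int annotation and the comparison against .value. A minimal sketch of the mismatch, with illustrative member values rather than the ones in the jax source:

from enum import Enum

class AttentionLayout(Enum):  # stand-in; member values are illustrative only
  BTNH = 0
  BNTH = 1

layout = 1                                   # layout arrives as a plain int
print(layout == AttentionLayout.BNTH)        # False: an int never equals an Enum member
print(layout == AttentionLayout.BNTH.value)  # True: compare against the member's int value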
23 changes: 16 additions & 7 deletions tests/fused_attention_stablehlo_test.py
@@ -429,18 +429,27 @@ def _cvt_back(x):
 
   def test_sdpa_utils(self):
     test_cases = [
-      (1, 257, 64, 8905, False, True),
-      (1, 1024, 64, 8905, False, False),
-      (1024, 1024, 64, 8905, False, False),
-      (1024, 1024, 128, 8905, False, False),
+      (1, 257, 64, 8905, False, True, True),
+      (1, 1024, 64, 8905, False, False, True),
+      (1024, 1024, 64, 8905, False, False, True),
+      (1024, 1024, 128, 8905, False, False, True),
+      (1024, 1024, 127, 8905, False, False, False),
     ]
 
     for k in test_cases:
-      sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training = k
+      sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training, \
+        expected_pass = k
       query = jnp.empty((4, sql_q, 4, head_dim))
       key = jnp.empty((4, sql_v, 4, head_dim))
-      check_is_flash_attention(
-          query, key, AttentionLayout.BNTH, cudnn_version, has_bias, is_training)
+      if expected_pass:
+        check_is_flash_attention(
+            query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
+            is_training)
+      else:
+        with self.assertRaises(NotImplementedError):
+          check_is_flash_attention(
+              query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
+              is_training)
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
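Each test case now carries an expected_pass flag, and the new (1024, 1024, 127, ...) case asserts that check_is_flash_attention raises NotImplementedError for an unsupported configuration. A more compact way to express the same two branches, shown only as a sketch with a hypothetical checker (the head-dim rule below is invented for illustration, not taken from the jax source), is contextlib.nullcontext:

import contextlib
import unittest

def toy_check(head_dim):
  # Hypothetical stand-in for check_is_flash_attention.
  if head_dim not in (64, 128):
    raise NotImplementedError(f"unsupported head_dim: {head_dim}")

class SketchTest(unittest.TestCase):
  def test_cases(self):
    for head_dim, expected_pass in [(64, True), (128, True), (127, False)]:
      # Enter a no-op context when the call should succeed; otherwise
      # assert that NotImplementedError is raised.
      ctx = (contextlib.nullcontext() if expected_pass
             else self.assertRaises(NotImplementedError))
      with ctx:
        toy_check(head_dim)

if __name__ == "__main__":
  unittest.main()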
