Commit 8ccc439

PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`

Imported from GitHub PR #23223

While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every new option must be tested in combination with all existing options, so the number of tests grows exponentially: we already have 864 parameterized tests for this API. This PR addresses the issue by grouping the tests to reduce their number.
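
As a sanity check on that figure, the counts below are read directly off the old and new `@parameterized.product` decorators in the diff; the snippet is illustrative only and not part of the test suite:

# Old suite: every option crossed with every other, duplicated across the
# Infer and Train test methods.
old_per_method = 3 * 2 * 3 * 3 * 2 * 2 * 2  # dtype x use_bias x causal_mode
                                            # x group_num x use_vmap
                                            # x use_seqlen x impl = 432
print(old_per_method * 2)       # 864 tests for Infer + Train

# New suite: one grouped functionality test plus one mask test.
new_non_mask = 2 * 3 * 2 * 2    # dtype x group_num x use_vmap x impl = 24
new_mask = 7                    # seven mask_mode cases
print(new_non_mask + new_mask)  # 31 tests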

We categorize the new tests as follows:

1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.

Additionally, we will no longer maintain separate tests for inference and training.
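
For reference, a minimal sketch of how these mask modes map onto a `jax.nn.dot_product_attention` call; the argument names follow the test code in the diff below, and the tensor shapes use its (batch, sequence, heads, head_dim) layout:

import jax.numpy as jnp
from jax import nn, random

B, S, T, N, H = 2, 128, 128, 4, 32
keys = random.split(random.PRNGKey(0), 3)
Q = random.normal(keys[0], (B, T, N, H), jnp.bfloat16)
K = random.normal(keys[1], (B, S, N, H), jnp.bfloat16)
V = random.normal(keys[2], (B, S, N, H), jnp.bfloat16)

# The ('causal', 'padding') combination: is_causal applies the causal mask,
# and the seq-length arrays mark the padded tail of each batch entry.
out = nn.dot_product_attention(
    Q, K, V,
    is_causal=True,
    query_seq_lengths=jnp.array([T // 2, T // 4], dtype=jnp.int32),
    key_value_seq_lengths=jnp.array([S // 4, S // 2], dtype=jnp.int32),
    implementation='cudnn',  # or None to use the default lowering
)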
Copybara import of the project:

--
dd2ca19 by kaixih <kaixih@nvidia.com>:

Reduce attn tests

Merging this change closes #23223

COPYBARA_INTEGRATE_REVIEW=#23223 from kaixih:reduce_attn_tests dd2ca19
PiperOrigin-RevId: 669364738
kaixih authored and jax authors committed Aug 30, 2024
1 parent e494de8 commit 8ccc439
Showing 1 changed file with 70 additions and 103 deletions.

tests/nn_test.py (173 changes: 70 additions & 103 deletions)
@@ -45,149 +45,116 @@ def _is_required_cudnn_version_satisfied():
      cuda_versions.cudnn_get_version() >= 8904
  )

-def _get_causal_mask(T, S):
-  causal_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
-  return causal_mask[jnp.newaxis, jnp.newaxis, :, :]
+def _check_cudnn_backend(fn, *args, **kwargs):
+  lowered = jax.jit(fn).lower(*args, **kwargs)
+  hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
+  return '__cudnn$fmha' in hlo

@jtu.with_config(jax_legacy_prng_key="allow",
                 jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
  @parameterized.product(
-      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
-      use_bias=[False, True],
-      causal_mode=[None, 'attr', 'mask'],
+      dtype=[jnp.bfloat16, jnp.float16],
      group_num=[1, 2, 4],
      use_vmap=[False, True],
-      use_seqlen=[False, True],
-      impl=['xla', 'cudnn'],
+      impl=['cudnn', 'xla'],
  )
-  def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode,
-                                   group_num, use_vmap, use_seqlen, impl):
+  def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
-    if impl == 'cudnn' and dtype == jnp.float32:
-      raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
-    if use_vmap and use_seqlen:
-      raise unittest.SkipTest("vmap cannot be used together with variable "
-                              "seqence lengths")

-    sdpa = nn.dot_product_attention
    B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
-    keys = random.split(random.PRNGKey(0), 4)
+    keys = random.split(random.PRNGKey(0), 5)
    Q = random.normal(keys[0], (B, T, N, H), dtype)
    K = random.normal(keys[1], (B, S, N // G, H), dtype)
    V = random.normal(keys[2], (B, S, N // G, H), dtype)
-    if use_bias:
-      bias = random.normal(keys[3], (1, N, T, S), dtype)
-    else:
-      bias = None
-    if use_seqlen:
-      q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
-      kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
-    else:
-      q_seqlen = None
-      kv_seqlen = None
+    grad = random.normal(keys[3], (B, T, N, H), dtype)
+    bias, mask = None, None

-    is_causal = causal_mode == 'attr'
-    causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None
+    sdpa = nn.dot_product_attention
+    sdpa_ref = partial(sdpa, implementation=None)
+    sdpa_ans = partial(sdpa, implementation=impl)
+    if use_vmap:
+      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)

-    sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
-    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
+    # For testing purposes, we call the non-GQA version without vmap in the
+    # reference code
+    K_ref = jnp.repeat(K, G, axis=2)
+    V_ref = jnp.repeat(V, G, axis=2)
+    out_ref, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, mask)
+    out_ans, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, mask)
+
+    dQ_ref, dK_ref, dV_ref = sdpa_vjp_ref(grad)[:3]
+    dQ_ans, dK_ans, dV_ans = sdpa_vjp_ans(grad)[:3]
+    dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
+    dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)

    if impl == 'cudnn':
-      lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask,
-                                        query_seq_lengths=q_seqlen,
-                                        key_value_seq_lengths=kv_seqlen)
-      hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
-      self.assertIn('__cudnn$fmha', hlo)
-
-    K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
-    V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
-    out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask,
-                       query_seq_lengths=q_seqlen,
-                       key_value_seq_lengths=kv_seqlen)
-    if use_vmap:
-      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
+      self.assertTrue(_check_cudnn_backend(sdpa_ans, Q, K, V, bias, mask))
+      self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))

-    out_ans = sdpa_ans(Q, K, V, bias, causal_mask,
-                       query_seq_lengths=q_seqlen,
-                       key_value_seq_lengths=kv_seqlen)
    self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
+    self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
+    self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
+    self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)

  @parameterized.product(
-      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
-      use_bias=[False, True],
-      causal_mode=[None, 'attr', 'mask'],
-      group_num=[1, 2, 4],
-      use_vmap=[False, True],
-      use_seqlen=[False, True],
-      impl=['xla', 'cudnn'],
+      mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
+                 ('custom', 'padding'), ('bias', 'causal')],
  )
-  def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode,
-                                   group_num, use_vmap, use_seqlen, impl):
-    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
+  def testDotProductAttentionMask(self, mask_mode):
+    if not _is_required_cudnn_version_satisfied():
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
-    if impl == 'cudnn' and dtype == jnp.float32:
-      raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
-    if use_vmap and use_seqlen:
-      raise unittest.SkipTest("vmap cannot be used together with variable "
-                              "seqence lengths")
-    if use_seqlen and use_bias and impl == 'cudnn':
-      raise unittest.SkipTest("cudnn has limited support for dbias when using "
-                              "variable seqence lengths")
+    if isinstance(mask_mode, str):
+      mask_mode = (mask_mode,)

-    sdpa = nn.dot_product_attention
-    B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
-    keys = random.split(random.PRNGKey(0), 5)
+    dtype = jnp.bfloat16
+    B, S, T, N, H = 2, 128, 128, 4, 32
+    keys = random.split(random.PRNGKey(0), 4)
    Q = random.normal(keys[0], (B, T, N, H), dtype)
-    K = random.normal(keys[1], (B, S, N // G, H), dtype)
-    V = random.normal(keys[2], (B, S, N // G, H), dtype)
+    K = random.normal(keys[1], (B, S, N, H), dtype)
+    V = random.normal(keys[2], (B, S, N, H), dtype)
    grad = random.normal(keys[3], (B, T, N, H), dtype)
-    if use_bias:
-      bias = random.normal(keys[4], (1, N, T, S), dtype)
-    else:
-      bias = None
-    if use_seqlen:
+    bias, mask = None, None
+    q_seqlen, kv_seqlen = None, None
+
+    is_causal = 'causal' in mask_mode
+    if 'padding' in mask_mode:
      q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
      kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
-    else:
-      q_seqlen = None
-      kv_seqlen = None
-
-    is_causal = causal_mode == 'attr'
-    causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None
+    if 'custom' in mask_mode:
+      # Use a generated causal mask as the custom mask.
+      custom_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
+      mask = custom_mask[None, None, :, :]
+    if 'bias' in mask_mode:
+      bias = random.normal(keys[4], (1, N, T, S), dtype)

-    K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
-    V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
+    sdpa = nn.dot_product_attention
    sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
-    # Convert the keyword arguments to positional ones.
+    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn')
+
+    args = (Q, K, V, bias, mask)
+    kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen}
+
+    # Convert the kargs to positional args for the jax.vjp.
    fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
-        q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs
+        q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
    )
-    _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask,
-                              q_seqlen, kv_seqlen)
+    fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
+        q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
+    )
+    out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen)
+    out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen)
    dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4]
-    if G != 1:
-      dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
-      dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)
-
-    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
-    if use_vmap and not use_seqlen:
-      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
-      _, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask)
-    else:
-      fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
-          q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs
-      )
-      _, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask, q_seqlen,
-                                kv_seqlen)
    dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4]

-    if impl == 'cudnn':
-      lowered = jax.jit(sdpa_vjp_ans).lower(grad)
-      hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
-      self.assertRegex(hlo, r'__cudnn\$fmha.*Backward\(')
+    # Check if cudnn backend is called.
+    self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
+    self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))

+    self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
    self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
    self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
    self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
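The new `_check_cudnn_backend` helper subsumes both of the old inline HLO checks (the forward `assertIn` and the backward `assertRegex`): it lowers any callable under `jax.jit` and searches the resulting StableHLO text for the cuDNN fused-attention custom call. A standalone usage sketch, assuming a CUDA build of JAX; the `mlir` import path is an assumption here, since the test file's own imports fall outside the visible hunk:

import jax
import jax.numpy as jnp
from jax import nn, random
from jax.interpreters import mlir  # assumed import path for module_to_string

def check_cudnn_backend(fn, *args, **kwargs):
  # Lower under jit and grep the StableHLO for the cuDNN fused-attention call.
  lowered = jax.jit(fn).lower(*args, **kwargs)
  hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
  return '__cudnn$fmha' in hlo

keys = random.split(random.PRNGKey(0), 3)
Q, K, V = (random.normal(k, (2, 128, 4, 32), jnp.bfloat16) for k in keys)
fwd = lambda q, k, v: nn.dot_product_attention(q, k, v, implementation='cudnn')
print(check_cudnn_backend(fwd, Q, K, V))  # True when lowering targets cuDNN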
