From 8ccc439d4a0ec23bd613f9a200cd1c154295873d Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Fri, 30 Aug 2024 10:11:19 -0700 Subject: [PATCH] PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention` Imported from GitHub PR https://github.com/google/jax/pull/23223 While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping. For the new tests, we categorize them as follows: 1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc. 2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations. Additionally, we will no longer maintain separate tests for inference and training. Copybara import of the project: -- dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih : Reduce attn tests Merging this change closes #23223 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5 PiperOrigin-RevId: 669364738 --- tests/nn_test.py | 173 +++++++++++++++++++---------------------------- 1 file changed, 70 insertions(+), 103 deletions(-) diff --git a/tests/nn_test.py b/tests/nn_test.py index a79cf738714b..3722db42671c 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -45,149 +45,116 @@ def _is_required_cudnn_version_satisfied(): cuda_versions.cudnn_get_version() >= 8904 ) -def _get_causal_mask(T, S): - causal_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) - return causal_mask[jnp.newaxis, jnp.newaxis, :, :] +def _check_cudnn_backend(fn, *args, **kwargs): + lowered = jax.jit(fn).lower(*args, **kwargs) + hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) + return '__cudnn$fmha' in hlo @jtu.with_config(jax_legacy_prng_key="allow", jax_numpy_dtype_promotion="standard") class NNFunctionsTest(jtu.JaxTestCase): @parameterized.product( - dtype=[jnp.float32, jnp.bfloat16, jnp.float16], - use_bias=[False, True], - causal_mode=[None, 'attr', 'mask'], + dtype=[jnp.bfloat16, jnp.float16], group_num=[1, 2, 4], use_vmap=[False, True], - use_seqlen=[False, True], - impl=['xla', 'cudnn'], + impl=['cudnn', 'xla'], ) - def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, - group_num, use_vmap, use_seqlen, impl): + def testDotProductAttention(self, dtype, group_num, use_vmap, impl): if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") - if use_vmap and use_seqlen: - raise unittest.SkipTest("vmap cannot be used together with variable " - "seqence lengths") - sdpa = nn.dot_product_attention B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num - keys = random.split(random.PRNGKey(0), 4) + keys = random.split(random.PRNGKey(0), 5) Q = random.normal(keys[0], (B, T, N, H), dtype) K = random.normal(keys[1], (B, S, N // G, H), dtype) V = random.normal(keys[2], (B, S, N // G, H), dtype) - if use_bias: - bias = random.normal(keys[3], (1, N, T, S), dtype) - else: - bias = None - if use_seqlen: - q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32) - kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32) - else: - q_seqlen = None - kv_seqlen = None + grad = random.normal(keys[3], (B, T, N, H), dtype) + bias, mask = None, None - is_causal = causal_mode == 'attr' - causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None + sdpa = nn.dot_product_attention + sdpa_ref = partial(sdpa, implementation=None) + sdpa_ans = partial(sdpa, implementation=impl) + if use_vmap: + sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) - sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) - sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) + # For testing purposes, we call the non-GQA version without vmap in the + # reference code + K_ref = jnp.repeat(K, G, axis=2) + V_ref = jnp.repeat(V, G, axis=2) + out_ref, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, mask) + out_ans, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, mask) + + dQ_ref, dK_ref, dV_ref = sdpa_vjp_ref(grad)[:3] + dQ_ans, dK_ans, dV_ans = sdpa_vjp_ans(grad)[:3] + dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) + dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) if impl == 'cudnn': - lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask, - query_seq_lengths=q_seqlen, - key_value_seq_lengths=kv_seqlen) - hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) - self.assertIn('__cudnn$fmha', hlo) - - K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K - V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V - out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask, - query_seq_lengths=q_seqlen, - key_value_seq_lengths=kv_seqlen) - if use_vmap: - sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) + self.assertTrue(_check_cudnn_backend(sdpa_ans, Q, K, V, bias, mask)) + self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad)) - out_ans = sdpa_ans(Q, K, V, bias, causal_mask, - query_seq_lengths=q_seqlen, - key_value_seq_lengths=kv_seqlen) self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) + self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) + self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) + self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) @parameterized.product( - dtype=[jnp.float32, jnp.bfloat16, jnp.float16], - use_bias=[False, True], - causal_mode=[None, 'attr', 'mask'], - group_num=[1, 2, 4], - use_vmap=[False, True], - use_seqlen=[False, True], - impl=['xla', 'cudnn'], + mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'), + ('custom', 'padding'), ('bias', 'causal')], ) - def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, - group_num, use_vmap, use_seqlen, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): + def testDotProductAttentionMask(self, mask_mode): + if not _is_required_cudnn_version_satisfied(): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") - if impl == 'cudnn' and dtype == jnp.float32: - raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") - if use_vmap and use_seqlen: - raise unittest.SkipTest("vmap cannot be used together with variable " - "seqence lengths") - if use_seqlen and use_bias and impl == 'cudnn': - raise unittest.SkipTest("cudnn has limited support for dbias when using " - "variable seqence lengths") + if isinstance(mask_mode, str): + mask_mode = (mask_mode,) - sdpa = nn.dot_product_attention - B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num - keys = random.split(random.PRNGKey(0), 5) + dtype = jnp.bfloat16 + B, S, T, N, H = 2, 128, 128, 4, 32 + keys = random.split(random.PRNGKey(0), 4) Q = random.normal(keys[0], (B, T, N, H), dtype) - K = random.normal(keys[1], (B, S, N // G, H), dtype) - V = random.normal(keys[2], (B, S, N // G, H), dtype) + K = random.normal(keys[1], (B, S, N, H), dtype) + V = random.normal(keys[2], (B, S, N, H), dtype) grad = random.normal(keys[3], (B, T, N, H), dtype) - if use_bias: - bias = random.normal(keys[4], (1, N, T, S), dtype) - else: - bias = None - if use_seqlen: + bias, mask = None, None + q_seqlen, kv_seqlen = None, None + + is_causal = 'causal' in mask_mode + if 'padding' in mask_mode: q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32) kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32) - else: - q_seqlen = None - kv_seqlen = None - - is_causal = causal_mode == 'attr' - causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None + if 'custom' in mask_mode: + # Use a generated causal mask as the custom mask. + custom_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) + mask = custom_mask[None, None, :, :] + if 'bias' in mask_mode: + bias = random.normal(keys[4], (1, N, T, S), dtype) - K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K - V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V + sdpa = nn.dot_product_attention sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) - # Convert the keyword arguments to positional ones. + sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn') + + args = (Q, K, V, bias, mask) + kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen} + + # Convert the kargs to positional args for the jax.vjp. fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref( - q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs + q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs, + ) + fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans( + q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs, ) - _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask, - q_seqlen, kv_seqlen) + out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen) + out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen) dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4] - if G != 1: - dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) - dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) - - sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) - if use_vmap and not use_seqlen: - sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) - _, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask) - else: - fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans( - q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs - ) - _, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask, q_seqlen, - kv_seqlen) dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4] - if impl == 'cudnn': - lowered = jax.jit(sdpa_vjp_ans).lower(grad) - hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) - self.assertRegex(hlo, r'__cudnn\$fmha.*Backward\(') + # Check if cudnn backend is called. + self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs)) + self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad)) + self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)