diff --git a/tests/nn_test.py b/tests/nn_test.py
index a79cf738714b..3722db42671c 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -45,149 +45,116 @@ def _is_required_cudnn_version_satisfied():
       cuda_versions.cudnn_get_version() >= 8904
   )
 
-def _get_causal_mask(T, S):
-  causal_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
-  return causal_mask[jnp.newaxis, jnp.newaxis, :, :]
+def _check_cudnn_backend(fn, *args, **kwargs):
+  lowered = jax.jit(fn).lower(*args, **kwargs)
+  hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
+  return '__cudnn$fmha' in hlo
 
 @jtu.with_config(jax_legacy_prng_key="allow",
                  jax_numpy_dtype_promotion="standard")
 class NNFunctionsTest(jtu.JaxTestCase):
   @parameterized.product(
-      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
-      use_bias=[False, True],
-      causal_mode=[None, 'attr', 'mask'],
+      dtype=[jnp.bfloat16, jnp.float16],
       group_num=[1, 2, 4],
       use_vmap=[False, True],
-      use_seqlen=[False, True],
-      impl=['xla', 'cudnn'],
+      impl=['cudnn', 'xla'],
   )
-  def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode,
-                                   group_num, use_vmap, use_seqlen, impl):
+  def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
     if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
     if impl == 'cudnn' and dtype == jnp.float32:
       raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
-    if use_vmap and use_seqlen:
-      raise unittest.SkipTest("vmap cannot be used together with variable "
-                              "seqence lengths")
 
-    sdpa = nn.dot_product_attention
     B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
-    keys = random.split(random.PRNGKey(0), 4)
+    keys = random.split(random.PRNGKey(0), 5)
     Q = random.normal(keys[0], (B, T, N, H), dtype)
     K = random.normal(keys[1], (B, S, N // G, H), dtype)
     V = random.normal(keys[2], (B, S, N // G, H), dtype)
-    if use_bias:
-      bias = random.normal(keys[3], (1, N, T, S), dtype)
-    else:
-      bias = None
-    if use_seqlen:
-      q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
-      kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
-    else:
-      q_seqlen = None
-      kv_seqlen = None
+    grad = random.normal(keys[3], (B, T, N, H), dtype)
+    bias, mask = None, None
 
-    is_causal = causal_mode == 'attr'
-    causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None
+    sdpa = nn.dot_product_attention
+    sdpa_ref = partial(sdpa, implementation=None)
+    sdpa_ans = partial(sdpa, implementation=impl)
+    if use_vmap:
+      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
 
-    sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
-    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
+    # For testing purposes, we call the non-GQA version without vmap in the
+    # reference code
+    K_ref = jnp.repeat(K, G, axis=2)
+    V_ref = jnp.repeat(V, G, axis=2)
+    out_ref, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, mask)
+    out_ans, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, mask)
+
+    dQ_ref, dK_ref, dV_ref = sdpa_vjp_ref(grad)[:3]
+    dQ_ans, dK_ans, dV_ans = sdpa_vjp_ans(grad)[:3]
+    dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
+    dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)
 
     if impl == 'cudnn':
-      lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask,
-                                        query_seq_lengths=q_seqlen,
-                                        key_value_seq_lengths=kv_seqlen)
-      hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
-      self.assertIn('__cudnn$fmha', hlo)
-
-    K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
-    V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
-    out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask,
-                       query_seq_lengths=q_seqlen,
-                       key_value_seq_lengths=kv_seqlen)
-    if use_vmap:
-      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
+      self.assertTrue(_check_cudnn_backend(sdpa_ans, Q, K, V, bias, mask))
+      self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))
 
-    out_ans = sdpa_ans(Q, K, V, bias, causal_mask,
-                       query_seq_lengths=q_seqlen,
-                       key_value_seq_lengths=kv_seqlen)
     self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
+    self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
+    self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
+    self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
 
   @parameterized.product(
-      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
-      use_bias=[False, True],
-      causal_mode=[None, 'attr', 'mask'],
-      group_num=[1, 2, 4],
-      use_vmap=[False, True],
-      use_seqlen=[False, True],
-      impl=['xla', 'cudnn'],
+      mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
+                 ('custom', 'padding'), ('bias', 'causal')],
   )
-  def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode,
-                                   group_num, use_vmap, use_seqlen, impl):
-    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
+  def testDotProductAttentionMask(self, mask_mode):
+    if not _is_required_cudnn_version_satisfied():
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
-    if impl == 'cudnn' and dtype == jnp.float32:
-      raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
-    if use_vmap and use_seqlen:
-      raise unittest.SkipTest("vmap cannot be used together with variable "
-                              "seqence lengths")
-    if use_seqlen and use_bias and impl == 'cudnn':
-      raise unittest.SkipTest("cudnn has limited support for dbias when using "
-                              "variable seqence lengths")
+    if isinstance(mask_mode, str):
+      mask_mode = (mask_mode,)
 
-    sdpa = nn.dot_product_attention
-    B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
-    keys = random.split(random.PRNGKey(0), 5)
+    dtype = jnp.bfloat16
+    B, S, T, N, H = 2, 128, 128, 4, 32
+    keys = random.split(random.PRNGKey(0), 4)
     Q = random.normal(keys[0], (B, T, N, H), dtype)
-    K = random.normal(keys[1], (B, S, N // G, H), dtype)
-    V = random.normal(keys[2], (B, S, N // G, H), dtype)
+    K = random.normal(keys[1], (B, S, N, H), dtype)
+    V = random.normal(keys[2], (B, S, N, H), dtype)
     grad = random.normal(keys[3], (B, T, N, H), dtype)
-    if use_bias:
-      bias = random.normal(keys[4], (1, N, T, S), dtype)
-    else:
-      bias = None
-    if use_seqlen:
+    bias, mask = None, None
+    q_seqlen, kv_seqlen = None, None
+
+    is_causal = 'causal' in mask_mode
+    if 'padding' in mask_mode:
       q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
       kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
-    else:
-      q_seqlen = None
-      kv_seqlen = None
-
-    is_causal = causal_mode == 'attr'
-    causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None
+    if 'custom' in mask_mode:
+      # Use a generated causal mask as the custom mask.
+      custom_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
+      mask = custom_mask[None, None, :, :]
+    if 'bias' in mask_mode:
+      bias = random.normal(keys[4], (1, N, T, S), dtype)
 
-    K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
-    V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
+    sdpa = nn.dot_product_attention
     sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
-    # Convert the keyword arguments to positional ones.
+    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn')
+
+    args = (Q, K, V, bias, mask)
+    kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen}
+
+    # Convert the kargs to positional args for the jax.vjp.
     fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
-        q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs
+        q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
+    )
+    fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
+        q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
     )
-    _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask,
-                              q_seqlen, kv_seqlen)
+    out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen)
+    out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen)
     dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4]
-    if G != 1:
-      dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
-      dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)
-
-    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
-    if use_vmap and not use_seqlen:
-      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
-      _, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask)
-    else:
-      fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
-          q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs
-      )
-      _, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask, q_seqlen,
-                                kv_seqlen)
     dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4]
 
-    if impl == 'cudnn':
-      lowered = jax.jit(sdpa_vjp_ans).lower(grad)
-      hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
-      self.assertRegex(hlo, r'__cudnn\$fmha.*Backward\(')
+    # Check if cudnn backend is called.
+    self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
+    self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))
+
+    self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
     self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
     self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
     self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
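Note for reviewers: the new _check_cudnn_backend helper simply lowers the jitted call and searches the StableHLO text for the cuDNN fused-attention custom call. Below is a minimal standalone sketch of the same check, not part of this change; it uses the public Lowered.as_text() accessor instead of the test's internal mlir helper, and it assumes a GPU with a cuDNN version that supports flash attention.

import jax
import jax.numpy as jnp
from functools import partial

def check_cudnn_backend(fn, *args):
  # Lower the jitted call and look for the cuDNN fused-attention custom call.
  hlo_text = jax.jit(fn).lower(*args).as_text()
  return '__cudnn$fmha' in hlo_text

B, T, N, H = 2, 128, 4, 32
q = jnp.zeros((B, T, N, H), jnp.bfloat16)
k = jnp.zeros((B, T, N, H), jnp.bfloat16)
v = jnp.zeros((B, T, N, H), jnp.bfloat16)
sdpa = partial(jax.nn.dot_product_attention, implementation='cudnn')
print(check_cudnn_backend(sdpa, q, k, v))  # Expected to print True on a supported GPU.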