
Commit

Reduce attn tests
kaixih committed Aug 26, 2024
1 parent be59f6e commit dd2ca19
Showing 1 changed file with 70 additions and 103 deletions.
173 changes: 70 additions & 103 deletions tests/nn_test.py
@@ -45,149 +45,116 @@ def _is_required_cudnn_version_satisfied():
cuda_versions.cudnn_get_version() >= 8904
)

def _get_causal_mask(T, S):
causal_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
return causal_mask[jnp.newaxis, jnp.newaxis, :, :]
def _check_cudnn_backend(fn, *args, **kwargs):
lowered = jax.jit(fn).lower(*args, **kwargs)
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
return '__cudnn$fmha' in hlo
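
The helper added here lowers a jitted call to StableHLO text and looks for the cuDNN fused-attention custom call. A minimal usage sketch, illustrative only and not part of the diff, assuming a CUDA build of JAX with a flash-attention-capable cuDNN:

    from functools import partial
    import jax.numpy as jnp
    from jax import nn

    # Shapes follow the test's (batch, seq_len, num_heads, head_dim) layout.
    q = k = v = jnp.zeros((2, 128, 4, 32), jnp.bfloat16)
    attn = partial(nn.dot_product_attention, implementation='cudnn')
    # True when the lowered StableHLO contains the '__cudnn$fmha' custom call.
    assert _check_cudnn_backend(attn, q, k, v)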

@jtu.with_config(jax_legacy_prng_key="allow",
jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
@parameterized.product(
dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
use_bias=[False, True],
causal_mode=[None, 'attr', 'mask'],
dtype=[jnp.bfloat16, jnp.float16],
group_num=[1, 2, 4],
use_vmap=[False, True],
use_seqlen=[False, True],
impl=['xla', 'cudnn'],
impl=['cudnn', 'xla'],
)
def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode,
group_num, use_vmap, use_seqlen, impl):
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
if use_vmap and use_seqlen:
raise unittest.SkipTest("vmap cannot be used together with variable "
"seqence lengths")

sdpa = nn.dot_product_attention
B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
keys = random.split(random.PRNGKey(0), 4)
keys = random.split(random.PRNGKey(0), 5)
Q = random.normal(keys[0], (B, T, N, H), dtype)
K = random.normal(keys[1], (B, S, N // G, H), dtype)
V = random.normal(keys[2], (B, S, N // G, H), dtype)
if use_bias:
bias = random.normal(keys[3], (1, N, T, S), dtype)
else:
bias = None
if use_seqlen:
q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
else:
q_seqlen = None
kv_seqlen = None
grad = random.normal(keys[3], (B, T, N, H), dtype)
bias, mask = None, None

is_causal = causal_mode == 'attr'
causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None
sdpa = nn.dot_product_attention
sdpa_ref = partial(sdpa, implementation=None)
sdpa_ans = partial(sdpa, implementation=impl)
if use_vmap:
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
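
The vmapped path relies on nn.dot_product_attention accepting unbatched (T, N, H) operands as well as batched ones: in_axes=(0, 0, 0, None, None) maps the leading batch axis of Q, K and V and broadcasts bias and mask. A standalone sketch of the same pattern on the default XLA path, illustrative only and not part of the diff:

    from functools import partial
    import jax
    import jax.numpy as jnp
    from jax import nn, random

    B, T, S, N, H = 2, 16, 16, 4, 8
    ks = random.split(random.PRNGKey(0), 3)
    q = random.normal(ks[0], (B, T, N, H), jnp.bfloat16)
    k = random.normal(ks[1], (B, S, N, H), jnp.bfloat16)
    v = random.normal(ks[2], (B, S, N, H), jnp.bfloat16)

    attn = partial(nn.dot_product_attention, implementation=None)
    # Map the batch axis of q/k/v; bias and mask stay None and are broadcast.
    batched = jax.vmap(attn, in_axes=(0, 0, 0, None, None), out_axes=0)
    out = batched(q, k, v, None, None)
    assert out.shape == (B, T, N, H)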

sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
# For testing purposes, we call the non-GQA version without vmap in the
# reference code
K_ref = jnp.repeat(K, G, axis=2)
V_ref = jnp.repeat(V, G, axis=2)
out_ref, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, mask)
out_ans, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, mask)

dQ_ref, dK_ref, dV_ref = sdpa_vjp_ref(grad)[:3]
dQ_ans, dK_ans, dV_ans = sdpa_vjp_ans(grad)[:3]
dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)
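
Context for the reshape-and-sum above: the reference path emulates grouped-query attention with plain multi-head attention by repeating each of the N // G shared KV heads G times, so the cotangent with respect to the repeated K/V has to be summed back over the G copies. A small self-contained sketch of that equivalence, illustrative only and not part of the diff:

    import jax
    import jax.numpy as jnp

    B, S, N, H, G = 2, 8, 4, 16, 2               # N query heads, N // G KV heads
    k = jnp.ones((B, S, N // G, H))
    repeat = lambda x: jnp.repeat(x, G, axis=2)  # GQA -> MHA on the head axis

    k_rep, vjp_fn = jax.vjp(repeat, k)           # forward: (B, S, N, H)
    dk_rep = jnp.ones_like(k_rep)                # some cotangent w.r.t. k_rep
    dk_auto = vjp_fn(dk_rep)[0]

    # Manual reduction used in the test: fold the G copies back onto each head.
    dk_manual = dk_rep.reshape(B, S, N // G, G, H).sum(axis=3)
    assert jnp.allclose(dk_auto, dk_manual)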

if impl == 'cudnn':
lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask,
query_seq_lengths=q_seqlen,
key_value_seq_lengths=kv_seqlen)
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertIn('__cudnn$fmha', hlo)

K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask,
query_seq_lengths=q_seqlen,
key_value_seq_lengths=kv_seqlen)
if use_vmap:
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
self.assertTrue(_check_cudnn_backend(sdpa_ans, Q, K, V, bias, mask))
self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))

out_ans = sdpa_ans(Q, K, V, bias, causal_mask,
query_seq_lengths=q_seqlen,
key_value_seq_lengths=kv_seqlen)
self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)

@parameterized.product(
dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
use_bias=[False, True],
causal_mode=[None, 'attr', 'mask'],
group_num=[1, 2, 4],
use_vmap=[False, True],
use_seqlen=[False, True],
impl=['xla', 'cudnn'],
mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
('custom', 'padding'), ('bias', 'causal')],
)
def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode,
group_num, use_vmap, use_seqlen, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
def testDotProductAttentionMask(self, mask_mode):
if not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
if use_vmap and use_seqlen:
raise unittest.SkipTest("vmap cannot be used together with variable "
"seqence lengths")
if use_seqlen and use_bias and impl == 'cudnn':
raise unittest.SkipTest("cudnn has limited support for dbias when using "
"variable seqence lengths")
if isinstance(mask_mode, str):
mask_mode = (mask_mode,)

sdpa = nn.dot_product_attention
B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
keys = random.split(random.PRNGKey(0), 5)
dtype = jnp.bfloat16
B, S, T, N, H = 2, 128, 128, 4, 32
keys = random.split(random.PRNGKey(0), 5)
Q = random.normal(keys[0], (B, T, N, H), dtype)
K = random.normal(keys[1], (B, S, N // G, H), dtype)
V = random.normal(keys[2], (B, S, N // G, H), dtype)
K = random.normal(keys[1], (B, S, N, H), dtype)
V = random.normal(keys[2], (B, S, N, H), dtype)
grad = random.normal(keys[3], (B, T, N, H), dtype)
if use_bias:
bias = random.normal(keys[4], (1, N, T, S), dtype)
else:
bias = None
if use_seqlen:
bias, mask = None, None
q_seqlen, kv_seqlen = None, None

is_causal = 'causal' in mask_mode
if 'padding' in mask_mode:
q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
else:
q_seqlen = None
kv_seqlen = None

is_causal = causal_mode == 'attr'
causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None
if 'custom' in mask_mode:
# Use a generated causal mask as the custom mask.
custom_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
mask = custom_mask[None, None, :, :]
if 'bias' in mask_mode:
bias = random.normal(keys[4], (1, N, T, S), dtype)
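
The 'custom' branch above builds an explicit boolean mask encoding the same constraint as is_causal=True, with True marking positions that may be attended, and the [None, None, :, :] indexing makes it broadcastable over batch and heads. A tiny sketch of that mask, illustrative only and not part of the diff:

    import jax.numpy as jnp

    T, S = 4, 4
    custom_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
    # [[ True False False False]
    #  [ True  True False False]
    #  [ True  True  True False]
    #  [ True  True  True  True]]
    mask = custom_mask[None, None, :, :]   # broadcastable to (B, N, T, S)
    assert mask.shape == (1, 1, T, S)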

K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K
V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V
sdpa = nn.dot_product_attention
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
# Convert the keyword arguments to positional ones.
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn')

args = (Q, K, V, bias, mask)
kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen}

# Convert the kwargs to positional args for jax.vjp.
fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
)
fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
)
_, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask,
q_seqlen, kv_seqlen)
out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen)
out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen)
dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4]
if G != 1:
dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)

sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
if use_vmap and not use_seqlen:
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
_, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask)
else:
fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs
)
_, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask, q_seqlen,
kv_seqlen)
dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4]
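
The fn_ref/fn_ans wrappers exist because jax.vjp passes its primals positionally; keyword arguments such as query_seq_lengths cannot be handed to jax.vjp directly, so they are lifted into positional slots first. A minimal illustration of the pattern, not from the diff:

    import jax
    import jax.numpy as jnp

    def scale(x, *, factor):                   # keyword-only parameter
      return x * factor

    x, factor = jnp.arange(3.0), 2.0
    wrapped = lambda x, f: scale(x, factor=f)  # lift the kwarg into a positional slot
    out, vjp_fn = jax.vjp(wrapped, x, factor)
    dx, dfactor = vjp_fn(jnp.ones_like(out))   # cotangents for both positional args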

if impl == 'cudnn':
lowered = jax.jit(sdpa_vjp_ans).lower(grad)
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertRegex(hlo, r'__cudnn\$fmha.*Backward\(')
# Check if cudnn backend is called.
self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))

self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)