From 4c642280d478b9228e77c3cab5cb85069b57e840 Mon Sep 17 00:00:00 2001
From: Gleb Pobudzey
Date: Wed, 28 Aug 2024 04:39:45 +0000
Subject: [PATCH] Use BlockSpecs when possible.

---
 jax/experimental/pallas/ops/gpu/attention.py | 58 +++++++++-----------
 1 file changed, 26 insertions(+), 32 deletions(-)

diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py
index 1cf8349e7da2..63541e8cb439 100644
--- a/jax/experimental/pallas/ops/gpu/attention.py
+++ b/jax/experimental/pallas/ops/gpu/attention.py
@@ -41,7 +41,7 @@ def mha_forward_kernel(
     block_d: int,
     block_k: int,
 ):
-  seq_len = q_ref.shape[0]
+  seq_len = k_ref.shape[0]
   start_q = pl.program_id(0)
 
   # o is the buffer where we accumulate the output on sram.
@@ -55,7 +55,7 @@ def mha_forward_kernel(
   # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
   # q tile has shape [block_q, block_d], block_d == head_dim.
   curr_q_slice = pl.dslice(start_q * block_q, block_q)
-  q = pl.load(q_ref, (curr_q_slice, pl.dslice(None)))
+  q = q_ref[...]
   q_segment_ids = (
       None
       if segment_ids_ref is None
@@ -123,12 +123,9 @@ def body(start_k, carry):
 
   if residual_refs:
     lse_ref = residual_refs[0]
-    lse_i = m_i + jnp.log(l_i)
-    pl.store(lse_ref, (curr_q_slice,), lse_i)
+    lse_ref[...] = m_i + jnp.log(l_i)
   # Write output to dram.
-  o = o.astype(o_ref.dtype)
-  pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o)
-
+  o_ref[...] = o.astype(o_ref.dtype)
 
 def segment_mask(
     q_segment_ids: jax.Array,
@@ -197,7 +194,7 @@ def mha(
 
   in_specs = [
       pl.BlockSpec(
-          (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+          (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
       ),
       pl.BlockSpec(
           (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
@@ -217,7 +214,7 @@ def mha(
       grid=grid_,
       in_specs=in_specs,
       out_specs=pl.BlockSpec(
-          (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+          (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
       ),
       compiler_params=dict(
           triton=dict(num_warps=num_warps_, num_stages=num_stages)
@@ -268,7 +265,7 @@ def _mha_forward(
   ]
   in_specs = [
       pl.BlockSpec(
-          (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+          (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
       ),
       pl.BlockSpec(
           (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
@@ -288,9 +285,9 @@ def _mha_forward(
       in_specs=in_specs,
      out_specs=[
           pl.BlockSpec(
-              (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+              (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
           ),
-          pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
+          pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
       ],
       compiler_params=dict(
           triton=dict(num_warps=num_warps_, num_stages=num_stages)
@@ -303,17 +300,14 @@ def _mha_forward(
   return out, (q, k, v, segment_ids, out, lse)
 
 
-def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, *, block_q: int):
-  pid_m = pl.program_id(0)
-
-  off_m = pl.ds(pid_m * block_q, block_q)
+def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref):
   # load
-  o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32)
-  do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32)
+  o = out_ref[...].astype(jnp.float32)
+  do = dout_ref[...].astype(jnp.float32)
   # compute
   delta = jnp.sum(o * do, axis=1)
   # write-back
-  pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype))
+  delta_ref[...] = delta.astype(delta_ref.dtype)
 
 
 @jax.named_scope("preprocess_backward")
@@ -321,17 +315,17 @@ def _preprocess_backward(out, do, lse, block_q: int,
   batch_size, seq_len, num_heads, head_dim = out.shape
   out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype)
   delta = pl.pallas_call(
-      functools.partial(_preprocess_backward_kernel, block_q=block_q),
+      _preprocess_backward_kernel,
      grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
       in_specs=[
           pl.BlockSpec(
-              (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+              (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
           ),
           pl.BlockSpec(
-              (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+              (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
           ),
       ],
-      out_specs=pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
+      out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
       compiler_params=dict(triton=dict(num_warps=4, num_stages=3)),
       out_shape=out_shape,
       debug=debug,
@@ -431,8 +425,8 @@ def inner_loop_dkdv(start_q, carry):
   dv, dk = lax.fori_loop(
       lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk)
   )
-  pl.store(dv_ref, (curr_k_slice, slice(None)), dv.astype(dv_ref.dtype))
-  pl.store(dk_ref, (curr_k_slice, slice(None)), dk.astype(dk_ref.dtype))
+  dv_ref[...] = dv.astype(dv_ref.dtype)
+  dk_ref[...] = dk.astype(dk_ref.dtype)
   del dv, dk
 
 
@@ -495,7 +489,7 @@ def inner_loop_dq(start_k, dq):
 
   upper_bound = pl.cdiv(seq_len, block_k2)
   dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
-  pl.store(dq_ref, (curr_q_slice, slice(None)), dq.astype(dq_ref.dtype))
+  dq_ref[...] = dq.astype(dq_ref.dtype)
 
 
 def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
@@ -566,16 +560,16 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
       grid=grid,
       out_specs=[
           pl.BlockSpec(
-              (None, seq_len, None, head_dim),
-              lambda i, j, _: (i, 0, j, 0),  # dq
+              (None, block_q, None, head_dim),
+              lambda i, j, k: (i, k, j, 0),  # dq
           ),
           pl.BlockSpec(
-              (None, seq_len, None, head_dim),
-              lambda i, j, _: (i, 0, j, 0),  # dk
+              (None, block_k, None, head_dim),
+              lambda i, j, k: (i, k, j, 0),  # dk
           ),
           pl.BlockSpec(
-              (None, seq_len, None, head_dim),
-              lambda i, j, _: (i, 0, j, 0),  # dv
+              (None, block_k, None, head_dim),
+              lambda i, j, k: (i, k, j, 0),  # dv
           ),
       ],
       name="mha_backward",
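
Illustrative note, not part of the patch: a minimal sketch of the BlockSpec indexing convention the diff switches to, assuming a (batch, seq_len, num_heads, head_dim) array and a grid of (q blocks, batch, heads). Each program instance is handed only its (block_q, head_dim) tile, which is why plain `ref[...]` reads and writes can replace the explicit pl.load/pl.store slicing. The names `copy_kernel` and `blocked_copy` are hypothetical and exist only for this example.

# Toy pallas_call showing the BlockSpec index_map used throughout the patch:
# grid axis 0 walks seq_len in block_q chunks, axes 1 and 2 pick the batch
# and head; `None` block dims are squeezed away inside the kernel.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def copy_kernel(x_ref, o_ref):
  # x_ref and o_ref hold one (block_q, head_dim) tile selected by the specs.
  o_ref[...] = x_ref[...]


def blocked_copy(x, block_q=128):
  batch_size, seq_len, num_heads, head_dim = x.shape
  spec = pl.BlockSpec(
      (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
  )
  return pl.pallas_call(
      copy_kernel,
      grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
      in_specs=[spec],
      out_specs=spec,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=True,  # interpret mode so the sketch also runs without a GPU
  )(x)


x = jnp.arange(2 * 256 * 4 * 64, dtype=jnp.float32).reshape(2, 256, 4, 64)
assert (blocked_copy(x) == x).all()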