Use BlockSpecs when possible.
Rifur13 committed Aug 28, 2024
1 parent 45dcfde commit 4c64228
Showing 1 changed file with 26 additions and 32 deletions.
58 changes: 26 additions & 32 deletions jax/experimental/pallas/ops/gpu/attention.py
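The pattern this commit moves the file to: rather than handing each kernel a full (seq_len, head_dim) ref and carving out its tile with pl.load / pl.store and pl.dslice offsets derived from pl.program_id, the BlockSpec block shape plus index map deliver one tile per program, so the kernel body simply reads ref[...] and assigns ref[...] = .... Below is a minimal, hedged sketch of that pattern; the kernel, names, and shapes are illustrative rather than taken from attention.py, and it assumes seq_len is a multiple of block_q and a backend that runs Pallas (Triton GPU or interpret mode).

import jax
from jax.experimental import pallas as pl

def scale_kernel(x_ref, o_ref):
  # x_ref and o_ref are already the (block_q, head_dim) tile chosen by the
  # BlockSpec index map, so no pl.load/pl.store or explicit offsets are needed.
  o_ref[...] = x_ref[...] * 2.0

def scale(x, block_q: int = 128):
  seq_len, head_dim = x.shape
  return pl.pallas_call(
      scale_kernel,
      grid=(pl.cdiv(seq_len, block_q),),
      # Program i is handed rows [i * block_q, (i + 1) * block_q).
      in_specs=[pl.BlockSpec((block_q, head_dim), lambda i: (i, 0))],
      out_specs=pl.BlockSpec((block_q, head_dim), lambda i: (i, 0)),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x)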
@@ -41,7 +41,7 @@ def mha_forward_kernel(
block_d: int,
block_k: int,
):
- seq_len = q_ref.shape[0]
+ seq_len = k_ref.shape[0]
start_q = pl.program_id(0)

# o is the buffer where we accumulate the output on sram.
@@ -55,7 +55,7 @@
# read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
# q tile has shape [block_q, block_d], block_d == head_dim.
curr_q_slice = pl.dslice(start_q * block_q, block_q)
- q = pl.load(q_ref, (curr_q_slice, pl.dslice(None)))
+ q = q_ref[...]
q_segment_ids = (
None
if segment_ids_ref is None
@@ -123,12 +123,9 @@ def body(start_k, carry):

if residual_refs:
lse_ref = residual_refs[0]
- lse_i = m_i + jnp.log(l_i)
- pl.store(lse_ref, (curr_q_slice,), lse_i)
+ lse_ref[...] = m_i + jnp.log(l_i)
# Write output to dram.
- o = o.astype(o_ref.dtype)
- pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o)
-
+ o_ref[...] = o.astype(o_ref.dtype)

def segment_mask(
q_segment_ids: jax.Array,
@@ -197,7 +194,7 @@ def mha(

in_specs = [
pl.BlockSpec(
- (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+ (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
@@ -217,7 +214,7 @@
grid=grid_,
in_specs=in_specs,
out_specs=pl.BlockSpec(
- (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+ (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
),
compiler_params=dict(
triton=dict(num_warps=num_warps_, num_stages=num_stages)
@@ -268,7 +265,7 @@ def _mha_forward(
]
in_specs = [
pl.BlockSpec(
- (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+ (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
@@ -288,9 +285,9 @@
in_specs=in_specs,
out_specs=[
pl.BlockSpec(
- (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+ (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
),
- pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
+ pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
],
compiler_params=dict(
triton=dict(num_warps=num_warps_, num_stages=num_stages)
@@ -303,35 +300,32 @@
return out, (q, k, v, segment_ids, out, lse)


- def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, *, block_q: int):
- pid_m = pl.program_id(0)
-
- off_m = pl.ds(pid_m * block_q, block_q)
+ def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref):
# load
- o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32)
- do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32)
+ o = out_ref[...].astype(jnp.float32)
+ do = dout_ref[...].astype(jnp.float32)
# compute
delta = jnp.sum(o * do, axis=1)
# write-back
- pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype))
+ delta_ref[...] = delta.astype(delta_ref.dtype)

@jax.named_scope("preprocess_backward")
def _preprocess_backward(out, do, lse, block_q: int,
debug: bool, interpret: bool):
batch_size, seq_len, num_heads, head_dim = out.shape
out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype)
delta = pl.pallas_call(
- functools.partial(_preprocess_backward_kernel, block_q=block_q),
+ _preprocess_backward_kernel,
grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
in_specs=[
pl.BlockSpec(
- (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+ (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
),
pl.BlockSpec(
- (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+ (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
),
],
- out_specs=pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
+ out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
compiler_params=dict(triton=dict(num_warps=4, num_stages=3)),
out_shape=out_shape,
debug=debug,
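As a reading aid for the hunk above: the preprocess kernel now receives one (block_q, head_dim) tile of out and dout and writes one (block_q,) tile of delta, but what it computes is unchanged: delta[i] is the sum over the head dimension of out[i, :] * dout[i, :]. A hedged plain-jnp reference for a single (batch, head) slice, purely illustrative and not part of the file:

import jax.numpy as jnp

def reference_delta(o, do):
  # o, do: (seq_len, head_dim); matches the kernel's jnp.sum(o * do, axis=1)
  # applied tile by tile, with float32 accumulation.
  return jnp.sum(o.astype(jnp.float32) * do.astype(jnp.float32), axis=-1)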
@@ -431,8 +425,8 @@ def inner_loop_dkdv(start_q, carry):
dv, dk = lax.fori_loop(
lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk)
)
- pl.store(dv_ref, (curr_k_slice, slice(None)), dv.astype(dv_ref.dtype))
- pl.store(dk_ref, (curr_k_slice, slice(None)), dk.astype(dk_ref.dtype))
+ dv_ref[...] = dv.astype(dv_ref.dtype)
+ dk_ref[...] = dk.astype(dk_ref.dtype)

del dv, dk

@@ -495,7 +489,7 @@ def inner_loop_dq(start_k, dq):
upper_bound = pl.cdiv(seq_len, block_k2)

dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
- pl.store(dq_ref, (curr_q_slice, slice(None)), dq.astype(dq_ref.dtype))
+ dq_ref[...] = dq.astype(dq_ref.dtype)


def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
@@ -566,16 +560,16 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
grid=grid,
out_specs=[
pl.BlockSpec(
- (None, seq_len, None, head_dim),
- lambda i, j, _: (i, 0, j, 0), # dq
+ (None, block_q, None, head_dim),
+ lambda i, j, k: (i, k, j, 0), # dq
),
pl.BlockSpec(
- (None, seq_len, None, head_dim),
- lambda i, j, _: (i, 0, j, 0), # dk
+ (None, block_k, None, head_dim),
+ lambda i, j, k: (i, k, j, 0), # dk
),
pl.BlockSpec(
- (None, seq_len, None, head_dim),
- lambda i, j, _: (i, 0, j, 0), # dv
+ (None, block_k, None, head_dim),
+ lambda i, j, k: (i, k, j, 0), # dv
),
],
name="mha_backward",
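For orientation, here is how the new forward-pass specs line up with the grid, written out as a hedged sketch: the attention arrays are laid out (batch_size, seq_len, num_heads, head_dim) and lse as (batch_size, num_heads, seq_len); grid axis i walks q blocks, j walks the batch, k walks the heads; and None entries in a block shape squeeze that axis away inside the kernel. The grid tuple below mirrors the one visible in the preprocess call, so treat its exact form in the forward call as an assumption, and the sizes are illustrative only.

from jax.experimental import pallas as pl

# Illustrative sizes, not taken from the file.
batch_size, seq_len, num_heads, head_dim, block_q = 2, 1024, 8, 64, 128

grid = (pl.cdiv(seq_len, block_q), batch_size, num_heads)  # (i, j, k)
q_spec = pl.BlockSpec(
    (None, block_q, None, head_dim),  # kernel sees a (block_q, head_dim) ref
    lambda i, j, k: (j, i, k, 0),     # batch j, q-block i, head k
)
kv_spec = pl.BlockSpec(
    (None, seq_len, None, head_dim),  # whole sequence; the kernel's inner loop walks it
    lambda _, j, k: (j, 0, k, 0),
)
lse_spec = pl.BlockSpec(
    (None, None, block_q),            # kernel sees a (block_q,) ref
    lambda i, j, k: (j, k, i),        # batch j, head k, q-block i
)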
