Add option for XLA backward pass in flash attention
sharadmv committed Feb 10, 2023
1 parent e3a1931, commit 88b0f4d
Showing 2 changed files with 56 additions and 44 deletions.
jax_triton/pallas/ops/attention.py (92 changes: 52 additions & 40 deletions)
@@ -103,19 +103,22 @@ def body(i, refs):
   acc = acc.astype(o_ref.dtype)
   pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc)

-@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10])
+@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10, 11])
 @functools.partial(jax.jit, static_argnames=["sm_scale", "block_q", "block_k",
+                                             "backward_pass_impl",
                                              "num_warps", "num_stages", "grid",
                                              "interpret", "debug"])
 def mha(q, k, v,
         sm_scale: float = 1.0,
         block_q: int = 128,
         block_k: int = 128,
+        backward_pass_impl: str = "triton",
         num_warps: Optional[int] = None,
         num_stages: int = 1,
         grid=None,
         interpret: bool = False,
         debug: bool = False):
+  del backward_pass_impl
   batch_size, seq_len, num_heads, head_dim = q.shape
   block_q = min(block_q, seq_len)
   block_k = min(block_k, seq_len)
@@ -156,8 +159,10 @@ def mha(q, k, v,
   return out

 def _mha_forward(q, k, v, sm_scale: float, block_q: int, block_k: int,
+                 backward_pass_impl: str,
                  num_warps: Optional[int], num_stages: int, grid: Any,
                  interpret: bool, debug: bool):
+  del backward_pass_impl
   batch_size, seq_len, num_heads, head_dim = q.shape
   block_q = min(block_q, seq_len)
   block_k = min(block_k, seq_len)
@@ -257,7 +262,7 @@ def mha_backward_kernel(
     *, sm_scale: float,
     block_q: int, block_d: int, block_k: int
 ):
-  del out_ref, l_ref # Not needed
+  del out_ref, l_ref  # Not needed
   seq_len = q_ref.shape[0]

   def outer_loop(start_k, _):
@@ -298,53 +303,60 @@ def inner_loop(start_q, refs):
                       slice(None)), dk.astype(dk_ref.dtype))
   for_loop(jt.cdiv(seq_len, block_k), outer_loop, ())

-def _mha_backward(sm_scale: float, block_q: int, block_k: int, num_warps:
-                  Optional[int], num_stages: int, grid: Any, interpret: bool,
+def _mha_backward(sm_scale: float, block_q: int, block_k: int,
+                  backward_pass_impl: str, num_warps: Optional[int],
+                  num_stages: int, grid: Any, interpret: bool,
                   debug: bool, res, do):
   del num_warps, num_stages, grid
   q, k, v, out, l, m = res

   batch_size, seq_len, num_heads, head_dim = q.shape
   block_q = min(block_q, seq_len)
   block_k = min(block_k, seq_len)
   do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)
-  # We accumulate into dq so we need to initialize it to zeros.
-  dq = jnp.zeros(q.shape, jnp.float32)
-
-  out_shapes = [
-    jax.ShapeDtypeStruct(dq.shape, dq.dtype),
-    jax.ShapeDtypeStruct(k.shape, k.dtype),
-    jax.ShapeDtypeStruct(v.shape, v.dtype),
-  ]
+  if backward_pass_impl == "xla":
+    return jax.vjp(mha_reference, q, k, v)[1](do)
+  elif backward_pass_impl == "triton":
+    # We accumulate into dq so we need to initialize it to zeros.
+    dq = jnp.zeros(q.shape, jnp.float32)
+    out_shapes = [
+      jax.ShapeDtypeStruct(dq.shape, dq.dtype),
+      jax.ShapeDtypeStruct(k.shape, k.dtype),
+      jax.ShapeDtypeStruct(v.shape, v.dtype),
+    ]

-  grid = (batch_size, num_heads)
-  num_warps = 8
-  dq, dk, dv = pl.pallas_call(
-      functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim,
-                        block_k=block_k, sm_scale=sm_scale),
-      grid=grid,
-      out_shape=out_shapes,
-      in_specs=[
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
-        pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
-        pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-      ],
-      out_specs=[
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-      ],
-      name="mha_backward",
-      debug=debug,
-      interpret=interpret,
-      num_warps=num_warps,
-      num_stages=1,
-      input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
+    grid = (batch_size, num_heads)
+    num_warps = 8
+    dq, dk, dv = pl.pallas_call(
+        functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim,
+                          block_k=block_k, sm_scale=sm_scale),
+        grid=grid,
+        out_shape=out_shapes,
+        in_specs=[
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
+          pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
+          pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        ],
+        out_specs=[
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        ],
+        name="mha_backward",
+        debug=debug,
+        interpret=interpret,
+        num_warps=num_warps,
+        num_stages=1,
+        input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
+  else:
+    raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
   return dq.astype(q.dtype), dk, dv
 mha.defvjp(_mha_forward, _mha_backward)

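For reference, a minimal usage sketch of the new backward_pass_impl option (not part of the commit). It assumes the wrapper above is importable as jax_triton.pallas.ops.attention and that inputs use the (batch_size, seq_len, num_heads, head_dim) layout the kernels expect; the shapes and toy loss below are illustrative only.

import jax
import jax.numpy as jnp
from jax_triton.pallas.ops import attention  # assumed import path for this sketch

batch_size, seq_len, num_heads, head_dim = 2, 512, 4, 64
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch_size, seq_len, num_heads, head_dim), jnp.float16)
k = jax.random.normal(kk, (batch_size, seq_len, num_heads, head_dim), jnp.float16)
v = jax.random.normal(kv, (batch_size, seq_len, num_heads, head_dim), jnp.float16)

def loss(q, k, v, impl):
  # The forward pass runs the same Pallas kernel either way; backward_pass_impl
  # only selects which backward pass the custom VJP uses under jax.grad.
  return attention.mha(q, k, v, backward_pass_impl=impl).sum()

dq_triton = jax.grad(loss)(q, k, v, "triton")  # fused Pallas/Triton backward kernel (default)
dq_xla = jax.grad(loss)(q, k, v, "xla")        # differentiate mha_reference with plain XLA

The "xla" option simply differentiates mha_reference via jax.vjp, as shown in the diff above, e.g. as a reference point for checking the Triton backward kernel's gradients.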
jax_triton/pallas/primitives.py (8 changes: 4 additions & 4 deletions)
@@ -326,9 +326,9 @@ def _load_pp_rule(eqn, context, settings):
   idx, *masked_other = tree_util.tree_unflatten(eqn.params["args_tree"], args)
   idx = _pp_idx(eqn.invars[0].aval, idx, context)
   lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
-  return [lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([
+  return pp.concat([lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([
     pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']')
-  ]))]
+  ]))])
 jax_core.pp_eqn_rules[load_p] = _load_pp_rule

 def _load_jvp(primals, tangents, *, args_tree, masked, **params: Any):
@@ -400,9 +400,9 @@ def _swap_pp_rule(eqn, context, settings):
   idx = _pp_idx(eqn.invars[0].aval, idx, context)
   lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
   if isinstance(y, jax_core.DropVar):
-    return [state_primitives.pp_ref(pp.concat([
+    return pp.concat([state_primitives.pp_ref(pp.concat([
       pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']'),
-      pp.text(" <- "), pp.text(jax_core.pp_var(val, context))]))]
+      pp.text(" <- "), pp.text(jax_core.pp_var(val, context))]))])
 jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule

 def _swap_jvp(primals, tangents, *, args_tree, masked, **params: Any):
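The primitives.py change above makes the load/swap pretty-printing rules return a single pretty-printer doc built with pp.concat instead of a Python list of docs. A minimal sketch of that pattern, assuming pp refers to JAX's internal pretty_printer module (an internal API that may change):

from jax._src import pretty_printer as pp  # assumed import for this sketch

# Build one combined Doc, as the updated rules now do, rather than returning a list.
doc = pp.concat([pp.text("a"), pp.text(" <- "), pp.text("x_ref[i]")])
print(doc.format())  # a <- x_ref[i]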
