From 88b0f4d8a331a949a1b7d7945718d6a1bc9992fa Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 10 Feb 2023 04:05:18 +0000 Subject: [PATCH] Add option for XLA backward pass in flash attention --- jax_triton/pallas/ops/attention.py | 92 +++++++++++++++++------------- jax_triton/pallas/primitives.py | 8 +-- 2 files changed, 56 insertions(+), 44 deletions(-) diff --git a/jax_triton/pallas/ops/attention.py b/jax_triton/pallas/ops/attention.py index 8c4d2359..8eab39a5 100644 --- a/jax_triton/pallas/ops/attention.py +++ b/jax_triton/pallas/ops/attention.py @@ -103,19 +103,22 @@ def body(i, refs): acc = acc.astype(o_ref.dtype) pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc) -@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10]) +@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10, 11]) @functools.partial(jax.jit, static_argnames=["sm_scale", "block_q", "block_k", + "backward_pass_impl", "num_warps", "num_stages", "grid", "interpret", "debug"]) def mha(q, k, v, sm_scale: float = 1.0, block_q: int = 128, block_k: int = 128, + backward_pass_impl: str = "triton", num_warps: Optional[int] = None, num_stages: int = 1, grid=None, interpret: bool = False, debug: bool = False): + del backward_pass_impl batch_size, seq_len, num_heads, head_dim = q.shape block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) @@ -156,8 +159,10 @@ def mha(q, k, v, return out def _mha_forward(q, k, v, sm_scale: float, block_q: int, block_k: int, + backward_pass_impl: str, num_warps: Optional[int], num_stages: int, grid: Any, interpret: bool, debug: bool): + del backward_pass_impl batch_size, seq_len, num_heads, head_dim = q.shape block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) @@ -257,7 +262,7 @@ def mha_backward_kernel( *, sm_scale: float, block_q: int, block_d: int, block_k: int ): - del out_ref, l_ref # Not needed + del out_ref, l_ref # Not needed seq_len = q_ref.shape[0] def outer_loop(start_k, _): @@ -298,53 +303,60 @@ def inner_loop(start_q, refs): slice(None)), dk.astype(dk_ref.dtype)) for_loop(jt.cdiv(seq_len, block_k), outer_loop, ()) -def _mha_backward(sm_scale: float, block_q: int, block_k: int, num_warps: - Optional[int], num_stages: int, grid: Any, interpret: bool, +def _mha_backward(sm_scale: float, block_q: int, block_k: int, + backward_pass_impl: str, num_warps: Optional[int], + num_stages: int, grid: Any, interpret: bool, debug: bool, res, do): del num_warps, num_stages, grid q, k, v, out, l, m = res + batch_size, seq_len, num_heads, head_dim = q.shape block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) - # We accumulate into dq so we need to initialize it to zeros. - dq = jnp.zeros(q.shape, jnp.float32) - out_shapes = [ - jax.ShapeDtypeStruct(dq.shape, dq.dtype), - jax.ShapeDtypeStruct(k.shape, k.dtype), - jax.ShapeDtypeStruct(v.shape, v.dtype), - ] + if backward_pass_impl == "xla": + return jax.vjp(mha_reference, q, k, v)[1](do) + elif backward_pass_impl == "triton": + # We accumulate into dq so we need to initialize it to zeros. 
+ dq = jnp.zeros(q.shape, jnp.float32) + out_shapes = [ + jax.ShapeDtypeStruct(dq.shape, dq.dtype), + jax.ShapeDtypeStruct(k.shape, k.dtype), + jax.ShapeDtypeStruct(v.shape, v.dtype), + ] - grid = (batch_size, num_heads) - num_warps = 8 - dq, dk, dv = pl.pallas_call( - functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim, - block_k=block_k, sm_scale=sm_scale), - grid=grid, - out_shape=out_shapes, - in_specs=[ - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - ], - out_specs=[ - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - ], - name="mha_backward", - debug=debug, - interpret=interpret, - num_warps=num_warps, - num_stages=1, - input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq) + grid = (batch_size, num_heads) + num_warps = 8 + dq, dk, dv = pl.pallas_call( + functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim, + block_k=block_k, sm_scale=sm_scale), + grid=grid, + out_shape=out_shapes, + in_specs=[ + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + ], + out_specs=[ + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + ], + name="mha_backward", + debug=debug, + interpret=interpret, + num_warps=num_warps, + num_stages=1, + input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq) + else: + raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") return dq.astype(q.dtype), dk, dv mha.defvjp(_mha_forward, _mha_backward) diff --git a/jax_triton/pallas/primitives.py b/jax_triton/pallas/primitives.py index 534760c5..03ff4ef9 100644 --- a/jax_triton/pallas/primitives.py +++ b/jax_triton/pallas/primitives.py @@ -326,9 +326,9 @@ def _load_pp_rule(eqn, context, settings): idx, *masked_other = tree_util.tree_unflatten(eqn.params["args_tree"], args) idx = _pp_idx(eqn.invars[0].aval, idx, context) lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes) - return [lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([ + return pp.concat([lhs, pp.text(' <- '), 
state_primitives.pp_ref(pp.concat([ pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']') - ]))] + ]))]) jax_core.pp_eqn_rules[load_p] = _load_pp_rule def _load_jvp(primals, tangents, *, args_tree, masked, **params: Any): @@ -400,9 +400,9 @@ def _swap_pp_rule(eqn, context, settings): idx = _pp_idx(eqn.invars[0].aval, idx, context) lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes) if isinstance(y, jax_core.DropVar): - return [state_primitives.pp_ref(pp.concat([ + return pp.concat([state_primitives.pp_ref(pp.concat([ pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']'), - pp.text(" <- "), pp.text(jax_core.pp_var(val, context))]))] + pp.text(" <- "), pp.text(jax_core.pp_var(val, context))]))]) jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule def _swap_jvp(primals, tangents, *, args_tree, masked, **params: Any):
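
Usage sketch (not part of the patch): the new backward_pass_impl argument selects how gradients of mha are computed: "triton" runs the fused Pallas backward kernel, while "xla" differentiates mha_reference with jax.vjp. The snippet below compares the two paths; the import path, shapes, and dtypes are illustrative assumptions, the forward pass still runs the Pallas kernel (so it needs a GPU, or interpret=True), and since the "xla" branch does not forward sm_scale to mha_reference, the example sticks to the default sm_scale. Expect small float16-level numerical differences between the two results.

import jax
import jax.numpy as jnp
from jax_triton.pallas.ops import attention  # assumed import path for the patched module

# Illustrative shapes/dtypes; layout is (batch, seq_len, num_heads, head_dim).
shape = (2, 128, 2, 64)
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, shape, dtype=jnp.float16)
k = jax.random.normal(key_k, shape, dtype=jnp.float16)
v = jax.random.normal(key_v, shape, dtype=jnp.float16)

def make_loss(impl):
  def loss(q, k, v):
    # Reduce to a float32 scalar so jax.grad can be applied directly.
    return attention.mha(q, k, v, backward_pass_impl=impl).astype(jnp.float32).sum()
  return loss

# Gradients via the fused Triton backward kernel (the default).
dq_t, dk_t, dv_t = jax.grad(make_loss("triton"), argnums=(0, 1, 2))(q, k, v)
# Gradients via the XLA backward pass added by this patch.
dq_x, dk_x, dv_x = jax.grad(make_loss("xla"), argnums=(0, 1, 2))(q, k, v)

# Largest elementwise deviation between the two backward implementations for dq.
print(jnp.max(jnp.abs(dq_t.astype(jnp.float32) - dq_x.astype(jnp.float32))))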