A faster flash attention bwd implementation #177
base: main
tonywu95 commented Jun 22, 2023 (edited)
- Decompose the bwd kernel into two kernels, one for dq and one for dk,dv.
- Extra parallelism over the sequence length axis.
- On a benchmark with causal=True, it is close to 6X faster than the previous implementation and ~3X faster than the XLA bwd pass.
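To make the decomposition concrete, here is a minimal NumPy sketch of the idea for a single head. It is not the Pallas code from this PR: the function names (`reference`, `block_probs`, `dq_pass`, `dkdv_pass`) are hypothetical, and it assumes the forward pass saved a per-row log-sum-exp `lse` and that `delta = rowsum(do * out)`. Each outer-loop iteration stands in for one independent program, which is where the extra parallelism over the sequence axis comes from.

```python
import numpy as np

def reference(q, k, v, do, scale, causal):
    """Dense, unblocked attention backward, used only as ground truth."""
    s = (q @ k.T) * scale
    if causal:
        s = np.where(np.tril(np.ones_like(s, dtype=bool)), s, -np.inf)
    m = s.max(-1)
    lse = m + np.log(np.exp(s - m[:, None]).sum(-1))  # per-row log-sum-exp
    p = np.exp(s - lse[:, None])                      # softmax probabilities
    out = p @ v
    delta = (do * out).sum(-1)                        # rowsum(do * out)
    ds = p * (do @ v.T - delta[:, None])
    return lse, delta, ds @ k * scale, ds.T @ q * scale, p.T @ do

def block_probs(q_blk, k_blk, lse_blk, q0, k0, scale, causal):
    """Recompute one (q block, k block) tile of the softmax from saved lse."""
    s = (q_blk @ k_blk.T) * scale
    if causal:
        rows = q0 + np.arange(q_blk.shape[0])[:, None]
        cols = k0 + np.arange(k_blk.shape[0])[None, :]
        s = np.where(rows >= cols, s, -np.inf)
    return np.exp(s - lse_blk[:, None])

def dq_pass(q, k, v, do, lse, delta, scale, causal, block_q, block_k):
    """'Kernel' 1: one program per q block, inner loop over k blocks."""
    dq = np.zeros_like(q)
    for q0 in range(0, q.shape[0], block_q):
        qs = slice(q0, q0 + block_q)
        acc = np.zeros_like(q[qs])  # local accumulator, stored once at the end
        # A tuned kernel could skip fully masked tiles; this sketch visits
        # them and relies on the mask producing zeros.
        for k0 in range(0, k.shape[0], block_k):
            ks = slice(k0, k0 + block_k)
            p = block_probs(q[qs], k[ks], lse[qs], q0, k0, scale, causal)
            ds = p * (do[qs] @ v[ks].T - delta[qs][:, None])
            acc += ds @ k[ks] * scale
        dq[qs] = acc
    return dq

def dkdv_pass(q, k, v, do, lse, delta, scale, causal, block_q, block_k):
    """'Kernel' 2: one program per k block, inner loop over q blocks."""
    dk, dv = np.zeros_like(k), np.zeros_like(v)
    for k0 in range(0, k.shape[0], block_k):
        ks = slice(k0, k0 + block_k)
        dk_acc, dv_acc = np.zeros_like(k[ks]), np.zeros_like(v[ks])
        for q0 in range(0, q.shape[0], block_q):
            qs = slice(q0, q0 + block_q)
            p = block_probs(q[qs], k[ks], lse[qs], q0, k0, scale, causal)
            ds = p * (do[qs] @ v[ks].T - delta[qs][:, None])
            dv_acc += p.T @ do[qs]
            dk_acc += ds.T @ q[qs] * scale
        dk[ks], dv[ks] = dk_acc, dv_acc
    return dk, dv

rng = np.random.default_rng(0)
seq_len, head_dim = 16, 8  # toy sizes, chosen only for illustration
q, k, v, do = (rng.standard_normal((seq_len, head_dim)) for _ in range(4))
scale = head_dim ** -0.5
lse, delta, dq_ref, dk_ref, dv_ref = reference(q, k, v, do, scale, causal=True)
dq = dq_pass(q, k, v, do, lse, delta, scale, True, block_q=4, block_k=4)
dk, dv = dkdv_pass(q, k, v, do, lse, delta, scale, True, block_q=4, block_k=4)
```

Because neither pass writes to memory owned by another program, no cross-program accumulation is needed. The cost is that the score tile is recomputed in both passes.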
High-level comment: the current backward pass is a fully fused kernel that parallelizes over only batch * num_heads threads.
For attention shapes with a small batch size and few heads (as is common in language model training), this kernel will underutilize the GPU.
However, there are applications where the fused kernel is faster than the two-kernel variant.
Could you add the two-kernel version as a separate backward pass implementation, so that users can select the one they want?
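A rough back-of-the-envelope sketch of the utilization argument, in plain Python. The shape below is a made-up example, the helper names are hypothetical, and comparing program count to SM count is only a coarse proxy for occupancy. The one hardware fact used is that an NVIDIA A100 has 108 SMs.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division: number of blocks of size b needed to cover a."""
    return -(-a // b)

def fused_programs(batch: int, num_heads: int) -> int:
    # Fused kernel: one program per (batch, head) pair.
    return batch * num_heads

def split_programs(batch: int, num_heads: int, seq_len: int, block: int) -> int:
    # Split kernels: additionally parallel over sequence blocks.
    return batch * num_heads * cdiv(seq_len, block)

# Hypothetical language-model-style shape: small batch, long sequence.
batch, num_heads, seq_len, block = 2, 8, 2048, 128
num_sms = 108  # NVIDIA A100

fused = fused_programs(batch, num_heads)
split = split_programs(batch, num_heads, seq_len, block)
print(f"fused: {fused} programs for {num_sms} SMs")
print(f"split: {split} programs for {num_sms} SMs")
```

With this shape the fused grid has fewer programs than the GPU has SMs, while the split grid has more than enough. With a large batch the fused grid can already fill the device, which is consistent with the fused kernel winning in some applications.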
Could you also add tests to pallas_test.py?
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
Can you rename to be (i, j, _)? Same below?
upper_bound = jt.cdiv(seq_len, block_k)
dq = lax.fori_loop(0, upper_bound, inner_loop, dq)
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), dq, eviction_policy="evict_last")
I don't think we need eviction policy here
slice(None)), dv.astype(dv_ref.dtype))
pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dk.astype(dk_ref.dtype))
Nit: indentation
@@ -346,6 +450,65 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
num_warps=num_warps,
num_stages=1,
input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
elif backward_pass_impl == "triton_split":
# We accumulate into dq so we need to initialize it to zeros.
Comment is not accurate here
@@ -346,6 +450,65 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
num_warps=num_warps,
num_stages=1,
input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
elif backward_pass_impl == "triton_split":
# We accumulate into dq so we need to initialize it to zeros.
out_shapes_q = jax.ShapeDtypeStruct(q.shape, jnp.float32)
I suspect we don't need dq to be f32 anymore. Could you try q.dtype?
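One plausible reading of this suggestion, sketched in NumPy rather than Pallas: if the split dq kernel keeps its running sum in a local float32 accumulator and writes the output once, the output buffer itself can have the narrower input dtype. The alternative shown for contrast rounds to float16 after every partial update, which is roughly what accumulating through a float16 output buffer would do. The shapes and the number of partial updates are arbitrary assumptions, and float16 stands in for whatever `q.dtype` is.

```python
import numpy as np

rng = np.random.default_rng(0)
# 64 hypothetical per-k-block contributions to one (16, 8) dq tile.
partials = rng.standard_normal((64, 16, 8)).astype(np.float32)
exact = partials.astype(np.float64).sum(axis=0)

# Local float32 accumulator, cast once when storing the result.
acc = np.zeros((16, 8), dtype=np.float32)
for part in partials:
    acc += part
dq_single_cast = acc.astype(np.float16)

# Accumulating through a float16 buffer: rounds after every update.
dq_roundtrip = np.zeros((16, 8), dtype=np.float16)
for part in partials:
    dq_roundtrip = (dq_roundtrip.astype(np.float32) + part).astype(np.float16)

err_single = np.abs(dq_single_cast.astype(np.float64) - exact).max()
err_roundtrip = np.abs(dq_roundtrip.astype(np.float64) - exact).max()
print(f"max abs error, single cast: {err_single:.4f}")
print(f"max abs error, per-step rounding: {err_roundtrip:.4f}")
```

Whether the real kernel behaves this way depends on how its accumulator is typed, so this would need checking against the tests requested above.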
@sharadmv Can this PR be merged? We see a big performance improvement on NVIDIA A100 GPUs with this PR.
I left some comments. @tonywu95 do you have time to address them?
Hey @tonywu95, is it ok if we take over this PR and put you as a co-author? We'd love to get it in!