[JAX] Support Ring Attention (Context Parallelism) #1059
Conversation
/te-ci jax L1

@staticmethod
@cache
def use_scanloop():
This is a bit ugly, but it is a necessity of using a scan loop, which is preferred because it gives XLA more control over unrolling. We don't prevent opting out of scan, but we currently warn the user to update their XLA flags. NVTE_FUSED_RING_ATTENTION_USE_SCAN defaults to 1 if undefined. Hopefully --xla_experimental_ignore_channel_id becomes the default in XLA:GPU at some point.
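A hedged sketch of how such an env-var-gated, cached helper might look. The class name, warning text, and XLA_FLAGS check are illustrative assumptions; only the NVTE_FUSED_RING_ATTENTION_USE_SCAN variable (default 1) and the --xla_experimental_ignore_channel_id flag come from the comment above:

```python
# Illustrative sketch only: RingAttentionHelper and the warning text are
# assumptions, not the PR's code.
import os
import warnings
from functools import cache


class RingAttentionHelper:
    @staticmethod
    @cache
    def use_scanloop() -> bool:
        """Return True when lax.scan should drive the ring loop."""
        use_scan = bool(int(os.environ.get("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))
        flags = os.environ.get("XLA_FLAGS", "")
        if use_scan and "--xla_experimental_ignore_channel_id" not in flags:
            warnings.warn(
                "Scan-based ring attention may require "
                "--xla_experimental_ignore_channel_id in XLA_FLAGS."
            )
        return use_scan
```

The @cache means the environment is consulted once per process, so the flag check does not sit on the hot path.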
# Combine KV tensors if separate for better permute scheduling and performance.
# Eventually XLA should perform this automatically.
kv = helper.stack_kv(k, v)
XLA:GPU does a poor job of scheduling two or more permutes with overlap, so we manually combine K and V into our combined KV format. There is a plan to add functionality to XLA to perform this fusion automatically.
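A minimal sketch of what such a combine step could look like; `stack_kv`/`unstack_kv` here are stand-ins, not the PR's actual helpers:

```python
# Illustrative stand-in for the combined-KV step: stacking K and V along a
# new leading axis yields one buffer, so each ring step needs a single
# collective permute instead of two.
import jax.numpy as jnp


def stack_kv(k, v):
    # kv has shape [2, *k.shape], with kv[0] == k and kv[1] == v.
    return jnp.stack([k, v], axis=0)


def unstack_kv(kv):
    return kv[0], kv[1]
```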
Thank you for the impressive work! I'll need a few days to complete my review of the partitioning details of the ring attention implementation.
/te-ci jax L1
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com> Signed-off-by: Ming Huang <mingh@nvidia.com>
/te-ci jax L1
LGTM!
return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

output_per_step, softmax_aux_per_step = lax.cond(
    idx == 0, causal_mask_compute, jax_cond_wrap
)
Could you clarify why the causal computation is applied at idx == 0? I initially thought the causal mask would be needed for the trailing block instead.
Here is a picture of what happens in both the unbalanced and the load-balanced case:
Each GPU_i starts with Q_i and KV_i. These are the causal-masked parts along the diagonal. All subsequent iterations are either skipped or unmasked partial computations. In the load-balanced case this is where the "half KV" and "half Q" cases come from.
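That schedule can be summarized with a small helper. This is an illustrative reconstruction of the unbalanced case, assuming KV travels so that rank r holds the block that started on rank (r - idx) mod cp_size at step idx:

```python
# Illustrative reconstruction of the unbalanced ring schedule (assumption:
# at step idx, rank r holds the KV block from rank (r - idx) % cp_size).
def step_kind(cp_rank: int, idx: int) -> str:
    if idx == 0:
        return "causal"  # diagonal block Q_r x KV_r: apply the causal mask
    if idx <= cp_rank:
        return "full"    # KV block below the diagonal: no mask needed
    return "skip"        # KV block above the diagonal: contributes nothing
```

For cp_size = 4, rank 0 computes only its causal diagonal block while rank 3 computes at every step; that imbalance is what the load-balanced "half KV" / "half Q" layout addresses.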
for idx in range(cp_size):
    output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp(
        softmax_aux_per_steps[idx] - softmax_aux
    ).transpose(0, 2, 1, 3)
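The correction above is the standard log-sum-exp merge. Here is a small self-contained sketch of that identity with simplified 1-D shapes and a hypothetical `merge_blocks` helper (not the kernel's layout or names): rescaling each block's partial output by exp(lse_block - lse_global) reproduces full softmax attention.

```python
# Simplified numerical sketch: combine per-block attention outputs using
# their per-block log-sum-exp statistics.
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def merge_blocks(partials, lses):
    """Merge per-block softmax-attention outputs via LSE rescaling."""
    lse = logsumexp(jnp.stack(lses), axis=0)  # global log-sum-exp
    return sum(p * jnp.exp(l - lse) for p, l in zip(partials, lses))
```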
Would it be feasible to transpose softmax_aux inside scan_block to allow for pipelining the transpose?
If it doesn't affect performance much, we can keep the transpose here.
Description
Add ring attention as an additional context-parallel strategy to the JAX fused attention API.
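As a rough mental model of the ring schedule (a single-process, unmasked reference with illustrative names and shapes, not the fused kernel or its API): each rank keeps its Q shard while KV shards rotate one hop per step, as lax.ppermute would do across devices.

```python
# Single-process reference for the (unmasked) ring schedule. Note: the raw
# exp accumulation here is not numerically safe; the real kernel uses
# online log-sum-exp corrections instead.
import jax.numpy as jnp


def ring_attention_reference(q_shards, k_shards, v_shards):
    cp_size = len(q_shards)
    num = [0.0] * cp_size  # running sum of exp(scores) @ V per rank
    den = [0.0] * cp_size  # running softmax denominator per rank
    k_ring, v_ring = list(k_shards), list(v_shards)
    for _ in range(cp_size):
        for r in range(cp_size):
            s = jnp.exp(q_shards[r] @ k_ring[r].T)
            num[r] = num[r] + s @ v_ring[r]
            den[r] = den[r] + s.sum(axis=-1, keepdims=True)
        # Rotate KV one hop around the ring (ppermute across devices).
        k_ring = k_ring[-1:] + k_ring[:-1]
        v_ring = v_ring[-1:] + v_ring[:-1]
    return [n / d for n, d in zip(num, den)]
```

After cp_size steps every rank has seen every KV shard, so the concatenated per-rank outputs match ordinary full attention.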
Type of change
Changes
Adds a new context parallel strategy and unit tests.
Checklist: