
[JAX] Support Ring Attention (Context Parallelism) #1059

Merged: 2 commits into NVIDIA:main from mingh/ring_attn_primitive on Nov 11, 2024

Conversation

@mingxu1067 (Collaborator) commented on Jul 30, 2024

Description

Add ring attention as an additional context parallel strategy to the JAX fused attention API.
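
For intuition, here is a conceptual, unmasked sketch of the ring exchange that this context parallel strategy builds on, written against public JAX APIs (lax.ppermute, lax.scan). It is not the PR's implementation and omits causal masking, the combined-KV format, and load balancing; the function name and the [batch, seq_chunk, heads, head_dim] layout are assumptions for illustration.

import jax.numpy as jnp
from jax import lax

def ring_attention_sketch(q, k, v, cp_size, axis_name="cp"):
    # Conceptual sketch only: unmasked ring attention over sequence-sharded
    # q/k/v of shape [batch, seq_chunk, heads, head_dim], intended to run
    # under shard_map/pmap with a mesh axis named `axis_name`.
    scale = q.shape[-1] ** -0.5
    perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
    batch, seq, heads, head_dim = q.shape
    acc = jnp.zeros((batch, heads, seq, head_dim), dtype=jnp.float32)
    row_max = jnp.full((batch, heads, seq), -jnp.inf, dtype=jnp.float32)
    row_sum = jnp.zeros((batch, heads, seq), dtype=jnp.float32)

    def body(carry, _):
        k_chunk, v_chunk, acc, row_max, row_sum = carry
        # Attend to the KV chunk currently held by this rank, keeping running
        # softmax statistics (online-softmax style) so partials can be merged.
        scores = jnp.einsum("bqhd,bkhd->bhqk", q, k_chunk).astype(jnp.float32) * scale
        new_max = jnp.maximum(row_max, scores.max(axis=-1))
        correction = jnp.exp(row_max - new_max)
        p = jnp.exp(scores - new_max[..., None])
        acc = acc * correction[..., None] + jnp.einsum(
            "bhqk,bkhd->bhqd", p, v_chunk.astype(jnp.float32)
        )
        row_sum = row_sum * correction + p.sum(axis=-1)
        # Rotate the KV chunk to the next rank in the ring.
        k_chunk = lax.ppermute(k_chunk, axis_name, perm=perm)
        v_chunk = lax.ppermute(v_chunk, axis_name, perm=perm)
        return (k_chunk, v_chunk, acc, new_max, row_sum), None

    (_, _, acc, _, row_sum), _ = lax.scan(
        body, (k, v, acc, row_max, row_sum), None, length=cp_size
    )
    out = acc / row_sum[..., None]
    return out.transpose(0, 2, 1, 3).astype(q.dtype)  # back to [b, s, h, d]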

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Adds ring attention as a new context parallel strategy, along with unit tests.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@mingxu1067 marked this pull request as draft on July 30, 2024 17:57
@mingxu1067 force-pushed the mingh/ring_attn_primitive branch 2 times, most recently from a42f6a2 to 7a147c8 on July 30, 2024 20:24
@mgoldfarb-nvidia force-pushed the mingh/ring_attn_primitive branch 7 times, most recently from f50f6f0 to 9fb3dc3 on November 5, 2024 22:07
@mgoldfarb-nvidia marked this pull request as ready for review on November 5, 2024 22:09
@mgoldfarb-nvidia (Collaborator) commented:

/te-ci jax L1


@staticmethod
@cache
def use_scanloop():

This is a bit ugly, but it is a necessity of using a scan loop, which is preferred because it gives XLA more control over unrolling. We don't prevent disabling scan, but we do warn the user to update their XLA flags. NVTE_FUSED_RING_ATTENTION_USE_SCAN defaults to 1 when undefined.

Hopefully --xla_experimental_ignore_channel_id becomes default in XLA:GPU at some point.
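
For illustration, a minimal sketch of what such a cached environment check might look like, assuming it only reads the variable named above; the class name, the XLA_FLAGS check, and the warning text are assumptions rather than the PR's actual code.

import os
import warnings
from functools import cache

class RingAttentionConfig:  # hypothetical holder for illustration only
    @staticmethod
    @cache
    def use_scanloop() -> bool:
        # NVTE_FUSED_RING_ATTENTION_USE_SCAN defaults to 1 when undefined.
        use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))
        if use_scan and "--xla_experimental_ignore_channel_id" not in os.getenv("XLA_FLAGS", ""):
            # Assumption: the warning nudges users toward the XLA flag mentioned above.
            warnings.warn(
                "Scan-based ring attention works best with "
                "--xla_experimental_ignore_channel_id in XLA_FLAGS."
            )
        return use_scan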


# Combine KV tensors if separate for better permute scheduling and performance.
# Eventually XLA should perform this automatically.
kv = helper.stack_kv(k, v)

XLA:GPU does a poor job of scheduling two or more collective permutes with overlap, so we manually combine K and V into our combined KV format. There is a plan to add functionality to XLA that performs this fusion automatically.
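
As a rough illustration (not the PR's actual helper), stacking K and V along a new leading axis lets a single collective permute move both tensors on every ring step; the helper names and axis choice here are assumptions.

import jax.numpy as jnp

def stack_kv(k, v):
    # Stack K and V along a new leading axis so one collective permute
    # moves both tensors per ring step.
    return jnp.stack([k, v], axis=0)

def unstack_kv(kv):
    # Split the combined tensor back into K and V before calling the kernel.
    return kv[0], kv[1]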

@zlsh80826 (Collaborator) left a comment:

Thank you for the impressive work! I’ll need a few days to complete my review of the partition details within the ring attention.

Resolved (now outdated) review threads on: transformer_engine/jax/cpp_extensions/attention.py, transformer_engine/jax/cpp_extensions/misc.py
@mgoldfarb-nvidia force-pushed the mingh/ring_attn_primitive branch 5 times, most recently from e35fd0d to 51922dc on November 7, 2024 23:28
@mgoldfarb-nvidia (Collaborator) commented:

/te-ci jax L1

Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>
@mgoldfarb-nvidia (Collaborator) commented:

/te-ci jax L1

@zlsh80826 (Collaborator) left a comment:

LGTM!

    return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

output_per_step, softmax_aux_per_step = lax.cond(
    idx == 0, causal_mask_compute, jax_cond_wrap
)

Could you clarify why causal computation is applied to idx == 0? I initially thought that the causal mask would be needed for the trailing block instead.


Here is a picture of what happens in both the unbalanced and load-balanced cases:

(figure: per-step Q/KV block schedule for unbalanced vs. load-balanced ring attention)

Each GPU_i starts with Q_i and KV_i; these are the causal-masked blocks along the diagonal. All subsequent iterations are either skipped or unmasked partial computations. In the load-balanced case this is where the "half KV" and "half Q" cases come from.
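
A condensed sketch of the per-step dispatch being discussed, with the operand plumbing stripped out; the wrapper function name and argument list are simplified for illustration, and only the branch structure mirrors the quoted code.

from jax import lax

def select_step_compute(idx, cp_rank, causal_mask_compute, no_mask_compute, skip_compute):
    # Step 0 operates on this rank's own (Q_i, KV_i) diagonal block, the only
    # block that needs the causal mask. Later steps hold KV received from
    # other ranks: they run an unmasked partial computation when the chunk
    # lies in this rank's causal past (idx <= cp_rank here), and skip it
    # otherwise.
    def jax_cond_wrap():
        return lax.cond(idx <= cp_rank, no_mask_compute, skip_compute)

    return lax.cond(idx == 0, causal_mask_compute, jax_cond_wrap)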

for idx in range(cp_size):
    output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp(
        softmax_aux_per_steps[idx] - softmax_aux
    ).transpose(0, 2, 1, 3)

Would it be feasible to transpose softmax_aux inside scan_block to allow for pipelining the transpose?


If it doesn't affect performance much, we can keep the transpose here.
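
For context, the correction being discussed can be read as a standalone combine step: each per-step partial output is rescaled by exp(step aux - global aux) before accumulation. The wrapper function below is illustrative; only the loop body is taken from the quoted snippet, and the aux layout (hence the transpose) is assumed to follow it.

import jax.numpy as jnp

def combine_ring_outputs(output_per_steps, softmax_aux_per_steps, softmax_aux, cp_size):
    # Each step produced a partial output using only its local softmax
    # statistics; rescaling by exp(step_aux - global_aux) and accumulating in
    # float32 recovers the globally normalized attention output.
    output = jnp.zeros(output_per_steps[0].shape, dtype=jnp.float32)
    for idx in range(cp_size):
        output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp(
            softmax_aux_per_steps[idx] - softmax_aux
        ).transpose(0, 2, 1, 3)
    return output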

@mgoldfarb-nvidia merged commit bfddb48 into NVIDIA:main on Nov 11, 2024
21 checks passed