
[JAX] Consolidate FFI and old descriptor implementation for fused attention. #1295

Merged

Conversation

mgoldfarb-nvidia
Collaborator

Description

Some customers still depend on the older descriptor-based custom-call approach for fused attention rather than the new FFI one. This change consolidates the two code paths so their behavior cannot diverge and updates are not accidentally applied to one but not the other. A sketch of the pattern appears below.

Fixes # (issue)
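
A minimal sketch of the consolidation pattern, under simplified, assumed signatures: the legacy descriptor-based custom call and the new FFI handler both forward to one shared implementation, so a fix lands in one place no matter which entry point is used. Only `FusedAttnForwardImpl`, `FusedAttnForwardFFI`, `Error_Type`, and `ffi_with_cuda_error_check` are named in this PR's review thread; everything else here is hypothetical.

```cpp
#include <cstddef>
#include <cuda_runtime.h>

// Placeholders for TE's FFI error type and CUDA error-check helper; the real
// definitions live in the Transformer Engine JAX extension sources.
using Error_Type = int;
static Error_Type ffi_with_cuda_error_check() {
  return cudaGetLastError() == cudaSuccess ? 0 : 1;
}

// Shared implementation: all fused-attention forward logic lives here, so
// both entry points below stay in sync by construction.
static void FusedAttnForwardImpl(cudaStream_t stream, void **buffers) {
  // ... launch the fused attention kernels on `stream` ...
}

// Legacy descriptor-based custom call (hypothetical signature): XLA passes an
// opaque descriptor blob that is unpacked and forwarded to the shared impl.
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                      std::size_t opaque_len) {
  // auto descriptor = UnpackOpaque<FusedAttnDescriptor>(opaque, opaque_len);
  FusedAttnForwardImpl(stream, buffers);
}

// New XLA FFI handler (hypothetical signature): decodes typed FFI arguments,
// forwards to the same shared impl, and returns the CUDA error check.
Error_Type FusedAttnForwardFFI(cudaStream_t stream, void **buffers) {
  FusedAttnForwardImpl(stream, buffers);
  return ffi_with_cuda_error_check();
}
```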

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Consolidate duplicate code.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@mgoldfarb-nvidia
Collaborator Author

/te-ci jax

@phu0ngng
Collaborator

Hi @mgoldfarb-nvidia, thanks for pushing this PR.

@zlsh80826 and I have also discussed ways to improve the FusedAttn APIs by using the new FFI features. One related PR is #1289; Reese can comment more on this.

Note that we plan to remove all the legacy custom calls once the new FFI-based custom calls are stable and there are no unexpected changes coming from XLA. The legacy custom calls will be kept for a few months as a fallback only. Customers are expected to switch to the new custom calls, which are in fact enabled by default.

@mgoldfarb-nvidia
Collaborator Author

mgoldfarb-nvidia commented Oct 28, 2024

Thanks @phu0ngng. We have some customers depending on the older descriptor-based interface, and we definitely want to make sure we don't break too much until we can get them moved over. At a minimum, we could try to separate the FFI layer from the underlying TE code just so it's easier to rebase changes onto other branches.

It looks like #1289 could be merged with this in a straightforward way.

@huanghua1994
Collaborator

This code will not pass compilation (see the sketch after this list):

  • static FusedAttnForwardImpl() is declared without a return type, yet its body ends with return ffi_with_cuda_error_check();
  • Error_Type FusedAttnForwardFFI() declares a return type but never returns ffi_with_cuda_error_check();
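
A hedged before/after sketch of the fix those two points imply, with simplified signatures (only the names quoted above come from this thread; `Error_Type` and `ffi_with_cuda_error_check` placeholders as in the description sketch):

```cpp
// Before (does not compile): the impl has no return type yet returns a value,
// while the FFI wrapper declares Error_Type but falls off the end.
//   static FusedAttnForwardImpl(...) { ...; return ffi_with_cuda_error_check(); }
//   Error_Type FusedAttnForwardFFI(...) { FusedAttnForwardImpl(...); }

// After: keep the shared impl void (it only launches work) and let the FFI
// wrapper own the error-check return value.
static void FusedAttnForwardImpl(cudaStream_t stream) {
  // ... launch kernels ...
}

Error_Type FusedAttnForwardFFI(cudaStream_t stream) {
  FusedAttnForwardImpl(stream);
  return ffi_with_cuda_error_check();
}
```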

Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
@mgoldfarb-nvidia
Collaborator Author

/te-ci jax

@zlsh80826
Collaborator

This is great! I had wanted to do the same thing.

@mgoldfarb-nvidia
Collaborator Author

/te-ci jax

@phu0ngng phu0ngng merged commit c036765 into NVIDIA:main Oct 30, 2024
13 of 14 checks passed