[JAX] Consolidate FFI and old descriptor implementation for fused attention. #1295
7ab4d8f to f7c936d
/te-ci jax
Hi @mgoldfarb-nvidia, thanks for pushing this PR. @zlsh80826 and I also discussed a way to improve the FusedAttn APIs by using new features from FFI. Note that we plan to remove all the legacy custom calls in the future, once the new custom calls with FFI are stable and there are no unexpected changes from XLA. The legacy custom calls will be kept for a few months as a fallback only. Customers are expected to switch to the new custom calls, which are in fact enabled by default.
Thanks @phu0ngng. We have some customers depending on the old FFI interface and will definitely want to make sure we don't break too much until we can get them moved over. At a minimum we could try to separate the FFI layer from the underlying TE code just so it's easier to rebase changes on other branches. It looks like #1289 could be merged with this in a straightforward way.
This code will not compile:
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
f7c936d to f529ed3
/te-ci jax
This is great! I wanted to do the same thing.
/te-ci jax
Description
Some customers still depend on the older descriptor-based approach to FFI for fused attention. This change consolidates the two methods so that they do not diverge in behavior and so that an update to one is not accidentally missed in the other.
Fixes # (issue)
Type of change
Changes
Consolidate duplicate code.
Checklist: