[TE/JAX] XLA FFI calls for Softmax and FusedAttnBackward #1319
Conversation
/te-ci jax L1
@zlsh80826 The forward test in
Signed-off-by: Hua Huang <huah@nvidia.com>
FusedAttnBackward passed all tests in test_fused_attn.py. Dequantize is not used currently; finished it for completeness. Signed-off-by: Hua Huang <huah@nvidia.com>
Signed-off-by: Hua Huang <huah@nvidia.com>
for more information, see https://pre-commit.ci
CI L1 tests passed. Rebase to the main branch to verify #1314
Force-pushed from 73e00ba to e891678 (compare)
/te-ci jax L1
Signed-off-by: Hua Huang <huah@nvidia.com>
/te-ci jax L1
LGTM for the fused attn
LGTM
Description
This PR introduces the following primitives, implemented with the new XLA FFI custom calls (a sketch of the shared handler pattern follows the list):
ScaledSoftmaxFwdPrimitive
ScaledSoftmaxBwdPrimitive
ScaledMaskedSoftmaxFwdPrimitive
ScaledMaskedSoftmaxBwdPrimitive
ScaledUpperTriangMaskedSoftmaxFwdPrimitive
ScaledUpperTriangMaskedSoftmaxBwdPrimitive
FusedAttnBwdPrimitive
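These primitives share a common shape on the C++ side: a typed FFI entry point that unpacks buffers and attributes, then launches the existing CUDA kernel on the XLA-provided stream. Below is a minimal, hypothetical sketch of that pattern using the public `xla::ffi` C++ API; the handler name, the `scale_factor` attribute, the f16 dtype, and the `LaunchScaledSoftmaxForward` launcher are illustrative assumptions, not the exact TE/JAX code.

```cpp
#include <cuda_runtime.h>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Assumed CUDA launcher, defined elsewhere in the extension (hypothetical).
void LaunchScaledSoftmaxForward(cudaStream_t stream, const void* input,
                                void* output, size_t num_elements,
                                float scale_factor);

// Thin FFI entry point: unpacks the typed buffers and the attribute, then
// forwards everything to the kernel launcher on the XLA stream.
ffi::Error ScaledSoftmaxFwdImpl(cudaStream_t stream,
                                ffi::Buffer<ffi::F16> input,
                                ffi::ResultBuffer<ffi::F16> output,
                                float scale_factor) {
  LaunchScaledSoftmaxForward(stream, input.untyped_data(),
                             output->untyped_data(), input.element_count(),
                             scale_factor);
  return ffi::Error::Success();
}

// Binds the typed signature to the XLA FFI calling convention; the resulting
// symbol is what gets registered as the custom-call target.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxFwdHandler, ScaledSoftmaxFwdImpl,
                              ffi::Ffi::Bind()
                                  .Ctx<ffi::PlatformStream<cudaStream_t>>()
                                  .Arg<ffi::Buffer<ffi::F16>>()   // input
                                  .Ret<ffi::Buffer<ffi::F16>>()   // output
                                  .Attr<float>("scale_factor"));
```

On the Python side, such a symbol would typically be exposed to JAX as a capsule and registered as a typed (FFI) custom-call target; the registration plumbing itself is TE's existing machinery and is not shown here.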
Also added DequantizeFFI() in transformer_engine/jax/csrc/extensions/quantization.cpp, although currently no Python function calls dequantize explicitly. All C++ functions in transformer_engine/jax/csrc/extensions have FFI implementations after this PR.
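For reference, here is a hedged sketch of how a dequantize handler might be bound, assuming an FP8 input dequantized to float32 with a single inverse-scale factor; the argument list and the `LaunchDequantize` launcher are assumptions, as the actual DequantizeFFI signature lives in quantization.cpp.

```cpp
#include <cuda_runtime.h>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Assumed CUDA launcher (hypothetical): out[i] = float(in[i]) * scale_inv[0].
void LaunchDequantize(cudaStream_t stream, const void* input,
                      const float* scale_inv, float* output,
                      size_t num_elements);

// Hypothetical handler shape; the real DequantizeFFI signature may differ.
// AnyBuffer leaves the FP8 flavor (e4m3/e5m2) to the kernel.
ffi::Error DequantizeFFI(cudaStream_t stream, ffi::AnyBuffer input,
                         ffi::Buffer<ffi::F32> scale_inv,
                         ffi::ResultBuffer<ffi::F32> output) {
  LaunchDequantize(stream, input.untyped_data(), scale_inv.typed_data(),
                   output->typed_data(), input.element_count());
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI,
                              ffi::Ffi::Bind()
                                  .Ctx<ffi::PlatformStream<cudaStream_t>>()
                                  .Arg<ffi::AnyBuffer>()           // FP8 input
                                  .Arg<ffi::Buffer<ffi::F32>>()    // scale_inv
                                  .Ret<ffi::Buffer<ffi::F32>>());  // output
```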
Type of change
Checklist: