Use torch sdpa implementation in ASR mha #9590
base: main
Conversation
I also attempted to run the tests in the repository but encountered an issue: NaNs appear when a mask with a fully False row is passed to MHA, because such a mask masks out every position in that row. What can be done about this? Should we write a separate test that doesn't use such masks as input?
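For reference, a minimal sketch reproducing this behaviour; the shapes and values are illustrative, not taken from the NeMo tests:

```python
# When every key is masked for a given query row, SDPA's softmax sees -inf
# everywhere and the output for that row becomes NaN.
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
mask[..., 0, :] = False  # first query attends to nothing

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.isnan(out).any())  # tensor(True): the fully masked row is NaN
```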
@SeanNaren @titu1994 The option of using SDPA is also a good thing because a Triton-based version of FAv2 with custom
@SeanNaren @titu1994 haha, and now that FAv3 is out, PyTorch will probably integrate it as well in the near term, for maximum brr on H100 :) so having NeMo's Conformer automatically benefit from this new work would be awesome
This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.
stale bump
Thanks for the contribution! Just some notes:
29ac6c0 to 741be10
@VahidooX Big thanks for the review!
Do you have any script for calculating metrics on LS? And reference metrics?
The PR looks great! Sorry for the delay in review.
I would be ok with making the use of torch SDPA default to True only if we can add a test that runs this function twice, with the flag set to True and then to False, and compares the outputs to have an MSE difference of 1e-4 or lower.
Could you add your example code as a unit test somewhere as a check?
Thanks @titu1994
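A hedged sketch of the kind of parity test being requested: run the same attention inputs through a manual implementation and through torch SDPA and assert the MSE stays below 1e-4. It compares against F.scaled_dot_product_attention directly rather than NeMo's MultiHeadAttention class, whose exact constructor arguments are not shown in this thread:

```python
import math
import torch
import torch.nn.functional as F


def manual_attention(q, k, v, mask):
    # Reference path: fill masked scores with a large negative value,
    # mirroring the pre-SDPA behaviour, then softmax and weight the values.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    scores = scores.masked_fill(~mask, -10000.0)
    return torch.matmul(torch.softmax(scores, dim=-1), v)


def test_sdpa_matches_manual():
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
    mask = torch.rand(2, 1, 16, 16) > 0.1
    mask[..., 0] = True  # ensure no row is fully masked (see the NaN issue above)

    sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    manual_out = manual_attention(q, k, v, mask)

    mse = torch.mean((sdpa_out - manual_out) ** 2)
    assert mse < 1e-4, f"MSE too large: {mse.item()}"
```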
Thanks much for the PR!
LGTM, minor question.
I evaluated the PR on the HF-Leaderboard datasets and observed no difference in WER. For the LS test set on an A6000, RTFx improved by 5%.
examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
I'm working on the tests and faced a problem with
Yes, that looks to be incorrect.
This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.
That sounds fine to me. Can you make it False by default? @titu1994 wdyt?
Sounds ok to me
Signed-off-by: WoodieDudy <goshagks@gmail.com>
Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
Signed-off-by: WoodieDudy <goshagks@gmail.com>
Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
I set the flag to False by default everywhere, and also added an argument
Let's merge? @titu1994
It looks ok for inference at least; let's avoid it for training for now.
Uh, sorry for the lack of responsiveness lately. I used bf16-mixed and PyTorch 2.4.0 with CUDA 12.5. I agree we should merge it disabled by default. Let's take a look again later with the newer backends. We should also enable it for inference by default if possible.
Just initiated the CI run; looks like some of the jobs are failing like this: https://github.com/NVIDIA/NeMo/actions/runs/11128061705/job/30922390803?pr=9590 Please address them, and once CI passes this is good to go. Thanks!
LGTM! Just left some minor comments.
@@ -136,6 +136,8 @@ model:
  xscaling: true # scales up the input embeddings by sqrt(d_model)
  untie_biases: true # unties the biases of the TransformerXL layers
  pos_emb_max_len: 5000
  use_pytorch_sdpa: false # use torch sdpa instead of manual attention
  use_pytorch_sdpa_backends: [torch.nn.attention.SDPBackend.MATH] # empty list means all backends https://pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html
Please set the backend list to empty.
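For context, a minimal sketch (not NeMo's actual code) of how such a backend list could be applied around the SDPA call, where an empty list means no restriction; it assumes the torch.nn.attention API available in PyTorch 2.3+:

```python
import contextlib
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


def run_sdpa(q, k, v, mask, backends):
    # An empty backend list leaves PyTorch free to pick any backend;
    # otherwise the list is handed to the sdpa_kernel context manager.
    ctx = sdpa_kernel(backends) if backends else contextlib.nullcontext()
    with ctx:
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


q = k = v = torch.randn(1, 4, 8, 16)
out = run_sdpa(q, k, v, None, [SDPBackend.MATH])  # restrict to the math backend
```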
@@ -48,6 +49,8 @@
    'PositionalEncoding',
]

inf_val = 10000.0
Please use "INF_VAL".
So, the CI is using torch==2.3.0a0+ebedce2, which has a slightly different way of setting SDPA backends. What is a good way to deal with this? Add a condition on the PyTorch version in MultiHeadAttention, or update the CI torch version?
Signed-off-by: WoodieDudy <goshagks@gmail.com>
86e60c3
Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
I suggest adding the condition so that NeMo can be used with earlier PyTorch versions.
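A hedged sketch of what such a version condition could look like; the cutoff logic and the fallback flags are assumptions, not the merged NeMo code:

```python
import contextlib
import torch


def sdpa_backend_context(backends):
    # No configured backends: do not restrict SDPA at all.
    if not backends:
        return contextlib.nullcontext()
    try:
        from torch.nn.attention import sdpa_kernel  # newer API (PyTorch >= 2.3)

        return sdpa_kernel(backends)
    except ImportError:
        # Older builds expose per-backend booleans instead of a backend list;
        # the specific flag mapping here is illustrative only.
        return torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=True, enable_mem_efficient=False
        )
```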
Hi. I changed the MHA implementation for the ASR modules so that it uses torch.nn.functional.scaled_dot_product_attention. This accelerated the forward pass in the MHA by 27% and the backward pass by 17% on an A100.
PyTorch SDPA is continuously being optimized, ensuring that we benefit from the latest performance improvements.
My code uses the memory-efficient backend in SDPA because flash attention doesn't support a custom attention bias. There is ongoing work to contribute custom bias support in the flash-attention repository (PR).
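As an illustration of that change (made-up shapes, not the actual NeMo module code), the relative-position scores become an additive attn_mask passed to SDPA, which rules out the flash backend and leaves the memory-efficient and math backends to handle the custom bias:

```python
import torch
import torch.nn.functional as F

batch, heads, seq, d_head = 2, 8, 100, 64
q = torch.randn(batch, heads, seq, d_head)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Additive bias, e.g. the relative positional term of Transformer-XL style attention.
pos_bias = torch.randn(batch, heads, seq, seq)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=pos_bias)
```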
What else do I need to do to merge this PR?
Usage
There is also my benchmark:
PR Type:
Who can review?
cc @titu1994 @SeanNaren
Additional Information