
Use torch sdpa implementation in ASR mha #9590

Open · wants to merge 11 commits into base: main
Conversation

@WoodieDudy commented Jul 2, 2024

Hi. I changed the MHA implementation in the ASR modules so that it uses torch.nn.functional.scaled_dot_product_attention.
This speeds up the MHA forward pass by 27% and the backward pass by 17% on an A100.
PyTorch SDPA is continuously being optimized, so we automatically benefit from future performance improvements.
My code uses the memory-efficient backend in SDPA because flash attention doesn't support a custom attention bias. There is ongoing work to contribute custom bias support to the flash-attention repository (PR).
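
For illustration, roughly how the attention core looks with SDPA (a simplified sketch, not the exact diff; the names, shapes, and backend selection here are assumptions, and the rel-shift of the positional term is omitted):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch >= 2.3 API

def rel_pos_sdpa_sketch(q_with_bias_u, q_with_bias_v, key, value, pos_k, scale):
    # content-position term (matrix_bd), passed to SDPA as an additive bias;
    # it is pre-scaled because SDPA only scales the q @ k^T term internally
    matrix_bd = torch.matmul(q_with_bias_v, pos_k.transpose(-2, -1)) * scale
    # flash attention cannot take an arbitrary bias, so restrict the backend
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        return F.scaled_dot_product_attention(
            q_with_bias_u, key, value, attn_mask=matrix_bd, dropout_p=0.0
        )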

What else do I need to do to get this PR merged?

Usage

Here is my benchmark script:

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention

torch.manual_seed(123)

device = "cuda"
batch_size = 32
seq_len = 1024
d_model = 512
n_head = 8

query = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
key = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False)
mask = torch.triu(mask, diagonal=1).bool()  # mask convention: True = masked out (zeroed), False = attend
mask = None  # benchmark without a mask
pos_emb = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)

attention_sdpa = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=True).to(device)
attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=False).to(device)
for original_param, sdpa_param in zip(attention_original.parameters(), attention_sdpa.parameters()):
    original_param.data.copy_(sdpa_param.data)

# attention_sdpa = torch.compile(attention_sdpa)
# attention_original = torch.compile(attention_original)


def measure_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='attention(query, key, value, mask, pos_emb);torch.cuda.synchronize()',
        setup='torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
    )

    with torch.no_grad():
        torch.cuda.synchronize()
        results = timer.blocked_autorange(min_run_time=10)
        forward_time = results.mean
        output = attention(query, key, value, mask, pos_emb)
    return forward_time, output


def measure_fwd_bwd_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='loss=attention(query, key, value, mask, pos_emb).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
    )
    torch.cuda.synchronize()
    results = timer.blocked_autorange(min_run_time=10)
    fwd_bwd_time = results.mean
    return fwd_bwd_time


time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb)

print(f"Original implementation time: {time_fwd_original:.6f} seconds")
print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds")
print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%")

time_fwd_bwd_original = measure_fwd_bwd_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask, pos_emb)
time_bwd_original = time_fwd_bwd_original - time_fwd_original
time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa

print(f"Original implementation backward time: {time_bwd_original:.6f} seconds")
print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds")
print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%")

print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")

# Original implementation time: 0.049075 seconds
# SDPA implementation time: 0.035598 seconds
# SDPA boost 27.46%
# Original implementation backward time: 0.127004 seconds
# SDPA implementation backward time: 0.104986 seconds
# SDPA backward boost 17.34%
# Outputs are the same

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Who can review?

cc @titu1994 @SeanNaren

Additional Information

@github-actions github-actions bot added the ASR label Jul 2, 2024
@WoodieDudy (Author)

I also attempted to run the tests in the repository but encountered an issue: NaNs appear when a mask with a fully False row is passed to the MHA. With such a mask, filling matrix_bd with -inf via matrix_bd.masked_fill_(mask.logical_not(), float("-inf")) produces a row consisting only of -inf, and after the softmax that entire row becomes NaN. I am unsure how to resolve this, since the softmax and the multiplication by value happen inside torch.nn.functional.scaled_dot_product_attention and I cannot intervene. In the existing implementation this is handled by manually zeroing the masked entries: attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0).
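
A standalone snippet showing the effect (not the NeMo code; the mask convention here is True = masked position):

import torch

scores = torch.randn(1, 4, 4)
mask = torch.zeros(1, 4, 4, dtype=torch.bool)
mask[0, 0, :] = True  # one query row is fully masked

scores = scores.masked_fill(mask, float("-inf"))
attn = torch.softmax(scores, dim=-1)
print(attn[0, 0])  # tensor([nan, nan, nan, nan])

# the manual implementation recovers by zeroing the masked entries afterwards
attn_fixed = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
print(attn_fixed[0, 0])  # tensor([0., 0., 0., 0.])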

What can be done about this? Should we write a separate test that doesn't use such masks as input?

@vadimkantorov (Contributor) commented Jul 9, 2024

@SeanNaren @titu1994 Having the option to use SDPA is also a good thing because a Triton-based version of FAv2 with custom attn_bias support (FlexAttention) is being added to PyTorch core (pytorch/pytorch#130250 (comment)), so Conformer attention can benefit in the future from the speed-ups and proper compilation work happening on SDPA in core PyTorch.

@vadimkantorov (Contributor)

@SeanNaren @titu1994 haha, and now that FAv3 is out, PyTorch will probably integrate it as well in the near term, for maximum brr on H100 :) so having NeMo's Conformer automatically benefit from this new work would be awesome.

@WoodieDudy (Author)

cc @redoctopus @jbalam-nv @okuchaiev

github-actions bot commented Aug 2, 2024

This PR is stale because it has been open for 14 days with no activity. Remove the stale label, comment, or update the PR, or it will be closed in 7 days.

@github-actions github-actions bot added the stale label Aug 2, 2024
@vadimkantorov (Contributor)

stale bump

@VahidooX (Collaborator) commented Aug 7, 2024

Thanks for the contribution!
@titu1994 please take a look at this PR, as it looks like an interesting addition to speed up Conformer models.

Just some notes:

  • Please use -10000 instead of -inf if possible, as -inf may cause NaNs with some data types (see the sketch after this list).

  • Please add it as an option to the config files (somewhere like here) so it can be controlled from the configs. Name it "use_pytorch_sdpa"?

  • I suggest making it True by default if we can make sure it works in all cases. @titu1994, what do you think?

  • Please evaluate one of the pretrained models from NGC on LibriSpeech test-other to make sure it produces exactly the same output and accuracy.

  • You need to set the dropout to zero manually in non-training (eval) mode, as SDPA does not respect the module's training state and always applies the given dropout (see the sketch below).

  • Have you used matrix_ac in your code/calculations?
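
For the masking and dropout points above, a minimal sketch of what is meant (the function and argument names here are illustrative assumptions, not the PR's code):

import torch.nn.functional as F

NEG_INF = -10000.0  # finite stand-in for -inf that stays representable in fp16/bf16

def sdpa_with_mask(q, k, v, bias, pad_mask, dropout_p, training):
    # pad_mask: True = masked-out position (convention assumed for this sketch)
    bias = bias.masked_fill(pad_mask, NEG_INF)
    # SDPA does not know whether the module is in eval mode,
    # so dropout has to be disabled explicitly outside of training
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=bias, dropout_p=dropout_p if training else 0.0
    )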

@WoodieDudy (Author) commented Aug 12, 2024

@VahidooX Big thanks for the review!

  • I replaced -inf with -10000
  • Added use_pytorch_sdpa to config
  • Fixed dropout for torch sdpa
  • I don't calculate matrix_ac manually, but it is computed under the hood of torch.nn.functional.scaled_dot_product_attention (see the reference implementation example):
    attn_weight = q_with_bias_u @ key.transpose(-2, -1) * scale_factor
    # so matrix_ac is equivalent to this attn_weight term

Do you have a script for computing metrics on LibriSpeech, and reference metrics to compare against?

@titu1994 (Collaborator) left a comment

The PR looks great! Sorry for the delay in review.

I would be OK with making torch SDPA the default (True) only if we can add a test that runs this function twice, once with the flag set to True and once to False, and checks that the outputs differ by an MSE of 1e-4 or less.

Could you add your example code somewhere as a unit test, as a check?
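
Roughly something like this, as a sketch only (it reuses the constructor and call signature from the benchmark above; the exact test location and fixtures are open):

import torch
import torch.nn.functional as F
from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention

@torch.no_grad()
def test_sdpa_matches_original():
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch, seq_len, d_model, n_head = 4, 128, 256, 4

    q = torch.rand(batch, seq_len, d_model, device=device)
    k = torch.rand(batch, seq_len, d_model, device=device)
    v = torch.rand(batch, seq_len, d_model, device=device)
    pos_emb = torch.rand(batch, seq_len, d_model, device=device)

    mha_sdpa = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=True).to(device)
    mha_orig = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=False).to(device)
    mha_orig.load_state_dict(mha_sdpa.state_dict())

    out_sdpa = mha_sdpa(q, k, v, None, pos_emb)
    out_orig = mha_orig(q, k, v, None, pos_emb)
    assert F.mse_loss(out_sdpa, out_orig) < 1e-4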

@WoodieDudy (Author)

Thanks @titu1994.
Okay, I'll try to add tests.

nithinraok previously approved these changes Aug 12, 2024

@nithinraok (Collaborator) left a comment

Thanks much for the PR!

LGTM, minor question.

I evaluated the PR on the HF-Leaderboard datasets and observed no difference in WER. For the LS test set, on an A6000, it improved RTFx by 5%.

Review thread on nemo/collections/asr/modules/conformer_encoder.py (outdated, resolved)
@WoodieDudy (Author)

I'm working on the tests and ran into a problem with tests/collections/asr/test_conformer_encoder.py:test_stochastic_depth_forward.
I think the random_length data in this test is incorrect, because random_length should be one-dimensional with shape (batch,), not two-dimensional. Am I right?

random_length = torch.tensor([2, 2], dtype=torch.int64)

@VahidooX (Collaborator)

I'm working on the tests and ran into a problem with tests/collections/asr/test_conformer_encoder.py:test_stochastic_depth_forward. I think the random_length data in this test is incorrect, because random_length should be one-dimensional with shape (batch,), not two-dimensional. Am I right?

random_length = torch.tensor([2, 2], dtype=torch.int64)

Yes, that looks incorrect.

github-actions bot

This PR is stale because it has been open for 14 days with no activity. Remove the stale label, comment, or update the PR, or it will be closed in 7 days.

@github-actions github-actions bot added the stale label Sep 25, 2024
@nithinraok (Collaborator)

That sounds fine to me. Can you make it False by default? @titu1994, wdyt?

@titu1994 (Collaborator)

Sounds ok to me

@github-actions github-actions bot removed the stale label Sep 26, 2024
WoodieDudy and others added 4 commits September 26, 2024 22:22
@WoodieDudy (Author)

I set the flag to False by default everywhere, and also added a use_pytorch_sdpa_backends argument that lets you specify the list of backends to use for SDPA.

@WoodieDudy (Author)

Let's merge? @titu1994

nithinraok previously approved these changes Sep 27, 2024

@nithinraok (Collaborator) left a comment

LGTM! @VahidooX @titu1994 @pzelasko for final review

@titu1994 (Collaborator)

It looks OK for inference at least; let's avoid it for training for now.

@WoodieDudy (Author)

@VahidooX @pzelasko 👀

@pzelasko (Collaborator) commented Oct 1, 2024

Uh, sorry for the lack of responsiveness lately. I used bf16-mixed and PyTorch 2.4.0 with CUDA 12.5. I agree we should merge it disabled by default. Let's take a look again later with the newer backends.

We should also enable it for inference by default if possible.

pzelasko previously approved these changes Oct 1, 2024
@nithinraok (Collaborator)

Just initiated the CI run; it looks like some of the jobs are failing like this: https://github.com/NVIDIA/NeMo/actions/runs/11128061705/job/30922390803?pr=9590

Please address them, and once CI passes this is good to go. Thanks!

VahidooX previously approved these changes Oct 1, 2024

@VahidooX (Collaborator) left a comment

LGTM! Just left some minor comments.

@@ -136,6 +136,8 @@ model:
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000
use_pytorch_sdpa: false # use torch sdpa instead of manual attention
use_pytorch_sdpa_backends: [torch.nn.attention.SDPBackend.MATH] # empty list means all backends https://pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html
Review comment (Collaborator):
Please set the backend list to empty.

@@ -48,6 +49,8 @@
'PositionalEncoding',
]

inf_val = 10000.0
Review comment (Collaborator):
Please use "INF_VAL".

@WoodieDudy (Author)

So the CI is using torch==2.3.0a0+ebedce2, which has a slightly different way of setting SDPA backends. What is a good way to deal with this? Add a condition on the PyTorch version in MultiHeadAttention.__init__ and map backend names to booleans,

enable_math = enable_mem_efficient = enable_flash = False
for backend in backends:
    if backend == SDPBackend.MATH:
        enable_math = True
    elif backend == SDPBackend.EFFICIENT_ATTENTION:
        enable_mem_efficient = True
    elif backend == SDPBackend.FLASH_ATTENTION:
        enable_flash = True
with torch.backends.cuda.sdp_kernel(enable_math=enable_math, enable_flash=enable_flash, enable_mem_efficient=enable_mem_efficient):
    ...

or update the CI torch version?

@WoodieDudy WoodieDudy dismissed stale reviews from VahidooX, pzelasko, and nithinraok via 86e60c3 October 1, 2024 19:09
@pzelasko (Collaborator) commented Oct 2, 2024

I suggest adding the condition so that NeMo can be used with earlier PyTorch versions.
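
For example, something along these lines (a sketch only; the helper name and the string-to-enum mapping are assumptions, not the PR's actual code):

from contextlib import nullcontext
import torch

try:  # PyTorch >= 2.3 exposes the enum-based context manager
    from torch.nn.attention import SDPBackend, sdpa_kernel
    HAVE_SDPA_KERNEL = True
except ImportError:
    HAVE_SDPA_KERNEL = False

def sdpa_backend_context(backend_names):
    # backend_names, e.g. ["MATH", "EFFICIENT_ATTENTION"]; empty means "no restriction"
    if not backend_names:
        return nullcontext()
    if HAVE_SDPA_KERNEL:
        return sdpa_kernel([getattr(SDPBackend, name) for name in backend_names])
    # older releases only have the boolean-flag context manager
    return torch.backends.cuda.sdp_kernel(
        enable_math="MATH" in backend_names,
        enable_flash="FLASH_ATTENTION" in backend_names,
        enable_mem_efficient="EFFICIENT_ATTENTION" in backend_names,
    )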

7 participants