[PyTorch] fused CUDNN attention kernel and sliding window attention #1197

Closed · Marks101 opened this issue Sep 23, 2024 · 3 comments · Fixed by #1212

@Marks101 (Contributor)

Hello team,

We have been noticing some fairly large deviations between the attention output of flash/unfused attention and the fused attention kernels when sliding window attention is active. The following sample illustrates this:

import torch

from transformer_engine.pytorch.attention import FlashAttention, FusedAttention, UnfusedDotProductAttention, get_swa_mask
from transformer_engine.common.recipe import DelayedScaling
import transformer_engine_torch as tex

# Sliding window: 1024 positions to the left, 0 to the right (causal-style)
window_size = (1024, 0)
seqlen, num_heads, kv_channels = 2048, 64, 64

# Random q/k/v in sbhd layout (sequence, batch, heads, head dim), batch size 1
q, k, v = [torch.randn(seqlen, 1, num_heads, kv_channels, dtype=torch.float16, device="cuda") for _ in range(3)]

flash_attn = FlashAttention(1.0)
fused_attn = FusedAttention(1.0)
unfused_attn = UnfusedDotProductAttention(1.0)

output_flash = flash_attn(q, k, v, "sbhd_sbhd_sbhd", window_size=window_size)
output_fused = fused_attn(q, k, v, "sbhd_sbhd_sbhd",
                          fused_attention_backend=tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                          window_size=window_size,
                          fp8_meta=dict(recipe=DelayedScaling()))

# The unfused path takes an explicit boolean mask, so build the sliding-window mask for it
attention_mask = torch.ones(1, 1, seqlen, seqlen, dtype=torch.bool, device="cuda")
attn_mask_type, attention_mask = get_swa_mask(window_size, seqlen, seqlen, "causal", attention_mask)
output_unfused = unfused_attn(q, k, v, attn_mask_type=attn_mask_type, attention_mask=attention_mask)

print("diff flash vs unfused:", torch.max(torch.abs(output_flash - output_unfused)).item())
print("diff fused vs unfused:", torch.max(torch.abs(output_fused - output_unfused)).item())

The output we see on H100 and CUDA 12.5 with CUDNN 9.2.1 is:

diff flash vs unfused: 0.03076171875
diff fused vs unfused: 4.8828125

The latter seems rather large. Can you reproduce these results?

@ksivaman (Member)

@cyanguwa Do you know what could be causing this?

@cyanguwa (Collaborator)

Hi @Marks101 ,

Thanks for raising this issue. I seem to have overlooked the different window_size definition in cuDNN. cuDNN supports a sliding window of (i - window_size_left, i], i.e. exclusive of the i - window_size_left element, whereas the original paper, flash-attn, and TE unfused DPA use the definition [i - window_size_left, i + window_size_right], which is inclusive of the boundary elements. Please give #1212 a try and let me know if there are still any issues. Thanks!
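
To make the two conventions concrete, here is a minimal standalone sketch (plain PyTorch, not the TE or cuDNN API) that builds both sliding-window masks for a causal case and shows they differ exactly on the j == i - window_size_left diagonal:

import torch

# Position i may attend to position j wherever the mask is True
seqlen, window_left = 8, 3
i = torch.arange(seqlen).unsqueeze(1)  # query positions, shape (seqlen, 1)
j = torch.arange(seqlen).unsqueeze(0)  # key positions, shape (1, seqlen)

# Inclusive convention (flash-attn / TE unfused DPA): j in [i - window_left, i]
mask_inclusive = (j >= i - window_left) & (j <= i)

# Exclusive-left convention (cuDNN): j in (i - window_left, i]
mask_exclusive = (j > i - window_left) & (j <= i)

# The masks disagree only where j == i - window_left, so the same
# window_size value produces different attention patterns
print((mask_inclusive ^ mask_exclusive).nonzero())

Under these definitions, an exclusive-left window of window_size_left + 1 covers the same positions as an inclusive window of window_size_left (with window_size_right = 0).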

Results:

diff flash vs unfused: 0.0330810546875
diff fused vs unfused: 0.033203125
diff flash vs   fused: 0.001953125

@Marks101 (Contributor, Author)

Hello Charlene,
Oh, I see, that makes sense. I just tested your fix in our training environment and I can confirm that the issue is fixed.
Thanks for looking into this so quickly and sharing the details 🥳
