[PyTorch] fused CUDNN attention kernel not properly handling strides #1195

Open
Marks101 opened this issue Sep 23, 2024 · 3 comments · May be fixed by #1214

@Marks101
Contributor

Hello team,

we noticed discrepancies when using transformer_engine.pytorch.TransformerLayer in combination with fused attention kernels, multi-/group-query attention, fuse_qkv_params, and qkv_weight_interleaved. The problem boils down to the following code snippet:

import torch

from transformer_engine.pytorch.attention import FlashAttention, FusedAttention
from transformer_engine.common.recipe import DelayedScaling
import transformer_engine_torch as tex

seqlen, q_heads, kv_heads, kv_channels = 2048, 16, 1, 64

# packed QKV buffer: the q, k and v heads share a single allocation
qkv = torch.randn(seqlen, 1, q_heads + 2 * kv_heads, kv_channels, dtype=torch.float16, device="cuda")

# splitting along the head dimension yields strided, non-contiguous views into the packed buffer
q, k, v = qkv.split([q_heads, kv_heads, kv_heads], dim=2)
# q, k, v = [t.contiguous() for t in (q, k, v)]

flash_attn = FlashAttention(1.0)
fused_attn = FusedAttention(1.0)

output_flash = flash_attn(q, k, v, "sbhd_sbhd_sbhd", window_size=(-1, 0))
output_fused = fused_attn(q, k, v, "sbhd_sbhd_sbhd",
                          fused_attention_backend=tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                          window_size=(-1, 0),
                          fp8_meta=dict(recipe=DelayedScaling()))

print("diff:", torch.max(torch.abs(output_fused - output_flash)).item())

On H100 with CUDA 12.5 and cuDNN 9.2.1 we get a max error of 6.1953125. If I uncomment the line that applies contiguous(), the error drops to 0.001953125, so I suspect the issue is related to how the fused attention backend handles strides.
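
For reference, here is a minimal sketch (same shapes as in the snippet above) that only prints the strides of the split views; it shows that q, k and v are non-contiguous views into the packed buffer:

import torch

seqlen, q_heads, kv_heads, kv_channels = 2048, 16, 1, 64

# packed QKV buffer, as in the repro above
qkv = torch.randn(seqlen, 1, q_heads + 2 * kv_heads, kv_channels, dtype=torch.float16, device="cuda")
q, k, v = qkv.split([q_heads, kv_heads, kv_heads], dim=2)

for name, t in (("q", q), ("k", k), ("v", v)):
    # all three views share the packed buffer, so the sequence and batch strides
    # are (q_heads + 2 * kv_heads) * kv_channels = 1152 instead of heads * kv_channels
    print(name, tuple(t.shape), t.stride(), t.is_contiguous())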

Could you please have a look at this and try to reproduce my results? Thanks!

@yaox12
Collaborator

yaox12 commented Sep 24, 2024

I think this issue is caused by the fact that we do stride checks in get_qkv_layout() in DotProductAttention.

assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

If you call FusedAttention directly, it does not run such checks, while FA does them in https://github.com/Dao-AILab/flash-attention/blob/53a4f341634fcbc96bb999a3c804c192ea14f2ea/flash_attn/flash_attn_interface.py#L90.
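
For illustration only, here is a standalone sketch of that kind of last-dimension stride check (not the actual Transformer Engine or flash-attn code):

import torch

def last_dim_contiguous(*tensors: torch.Tensor) -> bool:
    # mirrors the assert quoted above: every tensor must have stride 1 in its
    # last dimension; strided views into a packed QKV buffer still satisfy this
    return all(t.stride(-1) == 1 for t in tensors)

The split views from the repro pass this check, because only the head dimension is folded into the packed buffer while the channel dimension stays dense.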

cc @cyanguwa

@Marks101
Contributor Author

I just had a second look at the q, k, v tensor strides in my example. The strides are (1152, 1152, 64, 1) for all tensors. So at least the requirement that the last stride is 1 is fulfilled 😉

I added a get_qkv_layout() call to my sample, but it did not fail and returned the sbhd_sbhd_sbhd layout.

@cyanguwa cyanguwa self-assigned this Sep 25, 2024
@cyanguwa cyanguwa linked a pull request Sep 27, 2024 that will close this issue
@cyanguwa
Collaborator

cyanguwa commented Sep 27, 2024

The requirement for the last dimension's stride to be 1 is indeed satisfied here. But I think the problem is that the provided q, k, v tensors do not actually have an sbhd_sbhd_sbhd layout; it's more like an sb(h+2h_g)d layout :). The .contiguous() call forces them into sbhd_sbhd_sbhd, which is why the test passes then.
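
As a simplified illustration (a hypothetical heuristic, not the actual Transformer Engine logic), the difference between a true sbhd_sbhd_sbhd layout and views into one packed sb(h+2h_g)d buffer shows up in the batch stride:

import torch

def looks_like_packed_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> bool:
    # hypothetical check: genuinely separate sbhd tensors have a batch stride of
    # heads * head_dim, while views into one packed sb(h+2h_g)d buffer carry the
    # stride of all q + k + v heads
    return any(t.stride(1) != t.shape[2] * t.shape[3] for t in (q, k, v))

For the tensors in the repro this returns True: the batch stride is 1152, while q alone only covers 16 * 64 = 1024 elements per token.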

The DotProductAttention module calls get_qkv_layout() to run some checks on user inputs and convert them if necessary (in some limited capacity), but the FlashAttention and FusedAttention modules don't do those checks.

The get_qkv_layout() call returns sbhd_sbhd_sbhd because it misses this case in the last elif branch. I've created PR #1214 to improve the logic. For this specific case, the function will run run_iteratively() twice and force a .contiguous() on the tensors before passing them to FlashAttention or FusedAttention.

import os
import torch

from transformer_engine.pytorch.attention import FlashAttention, FusedAttention, DotProductAttention, _attention_backends
from transformer_engine.common.recipe import DelayedScaling
import transformer_engine_torch as tex

seqlen, q_heads, kv_heads, kv_channels = 2048, 16, 1, 64
seqlen_kv = 1024

qkv = torch.randn(seqlen, 1, q_heads + 2 * kv_heads, kv_channels, dtype=torch.float16, device="cuda")

q, k, v = qkv.split([q_heads, kv_heads, kv_heads], dim=2)
#q, k, v = [t.contiguous() for t in (q, k, v)]

flash_attn = DotProductAttention(q_heads, kv_channels=kv_channels, num_gqa_groups=kv_heads, softmax_scale=0.1)
fused_attn = DotProductAttention(q_heads, kv_channels=kv_channels, num_gqa_groups=kv_heads, softmax_scale=0.1)

os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
output_flash = flash_attn(q, k, v)

os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
output_fused = fused_attn(q, k, v)

print("diff:", torch.max(torch.abs(output_fused - output_flash)).item())

Results:

NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python test.py
[INFO     | DotProductAttention]: Running with FlashAttention backend (version 2.4.2)
[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)
diff: 0.0009765625
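
For completeness, the user-side workaround is the one already hinted at by the commented-out line in the original repro: materialize contiguous copies of the split views before handing them to the attention call.

# workaround: force a true sbhd_sbhd_sbhd layout by copying the packed views
q, k, v = [t.contiguous() for t in (q, k, v)]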
