[PyTorch] fused CUDNN attention kernel not properly handling strides #1195
Comments
I think this issue is caused by where we're doing the stride checks. If you call FusedAttention directly, it will not do such checks, while FA does them in https://github.com/Dao-AILab/flash-attention/blob/53a4f341634fcbc96bb999a3c804c192ea14f2ea/flash_attn/flash_attn_interface.py#L90.

cc @cyanguwa
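For reference, the flash-attention check linked above boils down to a last-dimension stride guard that falls back to contiguous(). The sketch below is a paraphrase of that idea, not a verbatim copy of the pinned commit:

```python
import torch

def maybe_contiguous(x):
    # Make a tensor contiguous only if its last-dimension stride is not 1;
    # flash-attention applies this kind of guard to q, k, v before launching the kernel.
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x

# A transposed view has last-dim stride != 1 and would get copied.
q = torch.randn(2, 8, 16, 64).transpose(-1, -2)
print(q.stride(-1))                    # 64 (not 1)
print(maybe_contiguous(q).stride(-1))  # 1
```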
I just had a second look at the q, k, v tensor strides in my example. The strides are […], and I added […].
The requirement for the last dimension's stride to be 1 seems to be fine here. But I think the problem is with the provided […]. The […]. The […].

Results: […]
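To make the stride situation concrete, here is a hypothetical illustration (shapes and layout are assumptions, not the reporter's actual values): when k and v are views into one interleaved kv buffer, their last-dimension stride is 1 even though the tensors are not contiguous, which is exactly the case a last-dim-only check does not catch.

```python
import torch

s, b, h_kv, d = 128, 2, 4, 64
# Interleaved key/value buffer of shape [s, b, h_kv, 2, d]
kv = torch.randn(s, b, h_kv, 2, d)
k, v = kv[..., 0, :], kv[..., 1, :]

print(k.shape)            # torch.Size([128, 2, 4, 64])
print(k.stride())         # (1024, 512, 128, 1) -> last-dim stride is 1
print(k.is_contiguous())  # False: a contiguous tensor would have stride (512, 256, 64, 1)
```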
Hello team,

we noticed discrepancies when using the transformer_engine.pytorch.TransformerLayer in combination with fused attention kernels, multi/group-query attention, fuse_qkv_params, and qkv_weight_interleaved. All in all, this problem boils down to the following code snippet:

On H100 with CUDA 12.5 and cuDNN 9.2.1 we get a max error of 6.1953125. If I uncomment the line that applies contiguous(), the error is down to 0.001953125; accordingly, I suspect that the issue is related to the handling of strides in the fused attention backend. Could you please have a look at this and try to reproduce my results? Thanks!
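The code snippet referenced in the report is not preserved in this thread. As an illustration only, here is a minimal, hypothetical sketch of the kind of comparison described: non-contiguous k/v views fed to TransformerEngine's DotProductAttention under the fused (cuDNN) backend, compared against contiguous copies. The shapes, the env-var backend selection, and the exact constructor/forward signatures are assumptions that may differ across TE versions; this is not the reporter's original code.

```python
import os
# Assumed backend toggles: disable flash attention and force the fused (cuDNN) path.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"

import torch
from transformer_engine.pytorch import DotProductAttention

s, b, h, h_kv, d = 128, 2, 16, 4, 64  # seq, batch, query heads, KV heads (GQA), head dim
attn = DotProductAttention(
    num_attention_heads=h,
    kv_channels=d,
    num_gqa_groups=h_kv,
    attention_dropout=0.0,
)

q = torch.randn(s, b, h, d, dtype=torch.bfloat16, device="cuda")
# k and v as views into one interleaved buffer: last-dim stride is 1,
# but the tensors themselves are not contiguous.
kv = torch.randn(s, b, h_kv, 2, d, dtype=torch.bfloat16, device="cuda")
k, v = kv[..., 0, :], kv[..., 1, :]

out_strided = attn(q, k, v)
out_contig = attn(q, k.contiguous(), v.contiguous())
print("max abs diff:", (out_strided - out_contig).abs().max().item())
```

If the fused backend reads the views using assumed (contiguous) strides, the two outputs diverge far beyond bf16 rounding error, which would match the large max error reported above.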