fused_attn_fwd_qkvpacked silently doesn't support 3 or 7 heads #1182
Hi @ajayjain, thanks for raising this issue. I understand how the h=3 case could be misinterpreted by TE, but I don't think h=7 should be a problem even before the fix. Could you please give PR 1187 a try and let me know if you still have issues with either case? Thanks.
When testing the fused_attn_fwd_qkvpacked function in transformer_engine.pytorch.cpp_extensions.fused_attn, the head dimension is dropped in the output if the input QKV matrix has layout "t3hd" and h=3 or h=7.

Minimal reproducible example:
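(The original snippet did not survive in this copy of the issue. The sketch below is an approximate reconstruction based on the reported shapes: 1024 total tokens, h=3 heads, head dim 128, "t3hd" packed layout. The argument list of fused_attn_fwd_qkvpacked, and the FusedAttnBackend / TE_DType helpers, are assumptions and may differ across Transformer Engine versions.)

```python
# Hedged reconstruction of the dropped reproducer, not the reporter's original code.
# Argument names/order for fused_attn_fwd_qkvpacked are assumptions; check your TE version.
import torch
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    FusedAttnBackend,
    fused_attn_fwd_qkvpacked,
)
from transformer_engine.pytorch.constants import TE_DType

t, h, d = 1024, 3, 128  # total tokens, attention heads, head dimension
# "t3hd" packed layout: [total_tokens, 3 (q/k/v), heads, head_dim]
qkv = torch.randn(t, 3, h, d, dtype=torch.bfloat16, device="cuda")
cu_seqlens = torch.tensor([0, t], dtype=torch.int32, device="cuda")  # a single sequence

out, aux = fused_attn_fwd_qkvpacked(
    True,                                    # is_training
    t,                                       # max_seqlen
    cu_seqlens,
    qkv,
    TE_DType[qkv.dtype],
    FusedAttnBackend["F16_arbitrary_seqlen"],
    qkv_layout="t3hd",
    attn_mask_type="no_mask",
)
# Reported behavior: the head dimension is dropped from out.shape for h=3 (and h=7).
print(out.shape)
```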
It should print [1024, 3, 128] instead.
h=1, 2, 4, 5, 6, and 8 work.
Would it be possible to support the h=3 case?
Thanks!