[PyTorch] Miscellaneous fixes for FA3 attention #1174
Conversation
LGTM. Only one small question/comment: it seems FA3 can only support FP8 with the BSHD/SBHD formats; the THD format is not supported with FP8. Should we add an assert message for this in TE? It will eventually trigger an error in Tri Dao's code anyway, but I think it is better to tell users about this in TE as well.
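For illustration, a TE-side guard along these lines could surface the limitation before the call ever reaches FA3. This is only a sketch: the argument names (`qkv_format`, `fp8_mha`) are placeholders and do not reflect TransformerEngine's actual internals.

```python
# Illustrative only: argument names are placeholders, not TE's real attributes.
def check_fa3_fp8_layout(qkv_format: str, fp8_mha: bool) -> None:
    """Fail fast in TE instead of deep inside the FA3 kernels."""
    if fp8_mha and qkv_format == "thd":
        raise AssertionError(
            "FA3 only supports FP8 attention with bshd/sbhd layouts; "
            "the thd (variable-length) layout is not supported with FP8 yet."
        )

# Example: check_fa3_fp8_layout("thd", fp8_mha=True) raises before calling FA3.
```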
@xrennvidia do you mind taking another look at the PR? I made a couple more changes after your last review. Thanks!
Description
This PR makes a few changes to the FA3 attention path.
- Add `descale_q`, `descale_k` and `descale_v` to the FA3 FP8 call. This allows for custom descaling factors instead of the default 1s for `q`, `k` and `v`. It requires FA3 PR 1210 to be in your FA3 installation (see the sketch after this list).
- Use `flash_attn_func` for FP8, since `flash_attn_varlen_func` does not support FP8 yet.
- Fix the `qkv_format=sbhd` case when `fp8_mha=true`.
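As a rough illustration of the first two items, the FA3 FP8 call might look like the sketch below. This is not TE's actual code: the import path, the descale tensor shapes, and the return signature all vary across FA3 builds, and the descale kwargs assume an installation that includes FA3 PR 1210.

```python
import torch

# Import path depends on how FA3 is installed (the hopper build may be
# packaged differently, e.g. as `flashattn_hopper`); adjust as needed.
from flash_attn_interface import flash_attn_func

b, s, h, d = 2, 512, 16, 64

# FP8 inputs in bshd layout; flash_attn_varlen_func (thd) does not take FP8
# yet, which is why the non-varlen flash_attn_func is used for FP8.
q = torch.randn(b, s, h, d, device="cuda").to(torch.float8_e4m3fn)
k = torch.randn(b, s, h, d, device="cuda").to(torch.float8_e4m3fn)
v = torch.randn(b, s, h, d, device="cuda").to(torch.float8_e4m3fn)

# Custom per-tensor descaling factors (arbitrary values for illustration;
# previously the call effectively used 1.0 for all three). Shape/dtype
# expectations may differ between FA3 versions.
descale_q = torch.tensor([0.5], dtype=torch.float32, device="cuda")
descale_k = torch.tensor([0.5], dtype=torch.float32, device="cuda")
descale_v = torch.tensor([0.5], dtype=torch.float32, device="cuda")

# Requires an FA3 build that includes flash-attention PR 1210; some versions
# return (out, softmax_lse) rather than a single tensor.
result = flash_attn_func(
    q, k, v,
    causal=True,
    descale_q=descale_q,
    descale_k=descale_k,
    descale_v=descale_v,
)
```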
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: