[C++/PyTorch] Add alibi_slopes support (#608)
* test alibi between fa and fu
* move alibi slopes and bias to global to avoid repeating calculation
* fix alibi slopes/bias generation
* fix _is_flash_attention_supported to allow alibi type
* disable padding mask when alibi is used for fused attn arbi backend
* add support for custom [n_heads] alibi_slopes in flash, fused, unfused attention
* clean up last commit
* remove alibi_type=none tests as they are unnecessary
* update cudnn-frontend to 1.0.2
* change bias/dbias shape to allow [b,1]/[1,h]/[b,h] in arbi backend
* tweak tests for arbi post_scale_bias [1,h,s,s] or alibi_slopes [n_heads]
* change bias/dbias shape in max512 backend - incomplete
* remove max512 changes from last commit and disable max512 (and arbi temporarily) for [b,h,s,s]; pending cuDNN backend support
* clean up and tweak backend selection logic
* replace || with () in docstring
* fix bias shape for max512 backend
* combine slopes/bias generation into one function get_alibi() and fix alibi tests
* fix lint
* fix PR557 bugs
* Update transformer_engine/pytorch/attention.py (Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>)
* encapsulate global alibi tensors into a dict cache
* reduce alibi slopes test size
* update to cudnn-frontend 1.0.3
* use dBias shape to define bias_b/bias_h because JAX materializes dBias rather than Bias in the bwd abstract

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
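The commits above add ALiBi slope and bias generation shared across the flash, fused, and unfused attention paths. As a rough illustration of what such a helper computes (this is a hedged sketch of the standard ALiBi recipe, not TransformerEngine's actual `get_alibi()` implementation; the function names here are hypothetical):

```python
import math
import torch

def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    """Sketch of standard ALiBi slopes: a geometric sequence
    2^(-8/n), 2^(-16/n), ..., 2^(-8) over the heads, with the usual
    interpolation trick when num_heads is not a power of two."""
    def pow2_slopes(n: int) -> list:
        start = 2.0 ** (-8.0 / n)
        return [start ** i for i in range(1, n + 1)]

    if math.log2(num_heads).is_integer():
        slopes = pow2_slopes(num_heads)
    else:
        # Fill the remaining heads by taking every other slope
        # from the next power-of-two schedule.
        closest = 2 ** math.floor(math.log2(num_heads))
        slopes = pow2_slopes(closest)
        slopes += pow2_slopes(2 * closest)[0::2][: num_heads - closest]
    return torch.tensor(slopes)

def get_alibi_bias(slopes: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Materialize an additive [n_heads, s, s] bias from per-head slopes:
    bias[h, i, j] = slopes[h] * (j - i), i.e. increasingly negative
    for tokens further in the past."""
    pos = torch.arange(seq_len)
    rel = pos[None, :] - pos[:, None]            # [s, s], value j - i
    return slopes[:, None, None] * rel[None, :]  # broadcast to [h, s, s]
```

Caching these tensors globally (as the commits do with a dict cache) avoids recomputing the bias on every forward pass when `num_heads` and `seq_len` are unchanged.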