[PyTorch] Miscellaneous fixes for FA3 attention (#1174)
* add qkv descales to FA3

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix sbhd shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* force the same dtype when comparing FA3 and cuDNN FP8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "force the same dtype when comparing FA3 and cuDNN FP8"

This reverts commit 19e7f87.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force the same dtype when comparing FA3 and cuDNN FP8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add try/except for FA3 when custom qkv descales are not supported

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace FA3 installation warning with a debug logging message

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused imports

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* avoid varlen_func for FP8 and improve messaging

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add SWA support for FA3

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change preference reason for FP8 logic

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] authored Oct 8, 2024
1 parent c3b3cd2 commit e762592
Showing 3 changed files with 107 additions and 69 deletions.
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
@@ -13,7 +13,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
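For reference, a minimal sketch of turning on the same debug output outside of CI, assuming (as the logging setup added to attention.py below suggests) that the environment variables are read when transformer_engine.pytorch is first imported:

import os

# NVTE_DEBUG enables debug mode; NVTE_DEBUG_LEVEL sets the verbosity.
# Their product selects the log level: 0 -> WARNING, 1 -> INFO, 2 -> DEBUG.
os.environ["NVTE_DEBUG"] = "1"
os.environ["NVTE_DEBUG_LEVEL"] = "2"

import transformer_engine.pytorch as te  # noqa: E402  # import only after setting the variables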
2 changes: 2 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -1319,6 +1319,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
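As an illustration of the two-line change above (a standalone sketch, not part of the test file): torch.testing.assert_close rejects mismatched dtypes by default, so the FA3 output must be cast to the reference dtype before the numerical comparison.

import torch

def assert_outputs_close(a, b, atol, rtol):
    # Mirrors the added lines: align dtypes first, then compare numerically.
    if a.dtype != b.dtype:
        a = a.to(b.dtype)
    torch.testing.assert_close(a, b, atol=atol, rtol=rtol)

# Example: a bfloat16 result compared against a float16 reference.
a = torch.randn(4, 8, dtype=torch.bfloat16)
assert_outputs_close(a, a.to(torch.float16), atol=1e-2, rtol=1e-2)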
172 changes: 104 additions & 68 deletions transformer_engine/pytorch/attention.py
@@ -85,6 +85,16 @@
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
@@ -100,29 +110,31 @@
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
_flash_attn_3_plus = False
_use_flash_attn_3 = False
_flash_attn_3_installation_steps = """\
(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
try:
_flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.9")
_flash_attn_3_0_0_beta = _flash_attn_3_plus and _flash_attn_v3_version < PkgVersion("3.0.0")
except PackageNotFoundError:
if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
warnings.warn(
"To use flash-attn v3, please use the following commands to install: \n"
"""(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
"""(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
"""(3) mkdir -p $python_path/flashattn_hopper \n"""
"""(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
fa3_logger = logging.getLogger()
fa3_logger.setLevel(_log_level)
if not fa3_logger.hasHandlers():
fa3_logger.addHandler(_stream_handler)
fa3_logger.debug(
"To use flash-attn v3, please follow these steps to install the flashattn-hopper "
"package: \n%s",
_flash_attn_3_installation_steps,
)
else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_forward as _flash_attn_forward_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_backward as _flash_attn_backward_v3,
)

_use_flash_attn_3 = True
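
For orientation, a stripped-down sketch of the guarded-import pattern implemented in the hunk above: probe the optional flashattn-hopper package and, when it is missing, emit a debug-level log instead of an unconditional warning. Package and version names follow the diff; the use of packaging.version and the logger name are assumptions made for this standalone sketch.

import logging
from importlib.metadata import PackageNotFoundError, version as get_pkg_version
from packaging.version import Version as PkgVersion

logger = logging.getLogger("FlashAttention3")  # hypothetical logger name for this sketch
_use_flash_attn_3 = False
try:
    # Absence of the optional package is not an error.
    fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
    _use_flash_attn_3 = fa3_version >= PkgVersion("2.9")
except PackageNotFoundError:
    # Downgraded from warnings.warn() so default runs stay quiet.
    logger.debug("flashattn-hopper is not installed; FlashAttention 3 will not be used")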

@@ -132,18 +144,6 @@
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd


# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

_attention_backends = {
"attention_params": None,
"use_flash_attention": None,
@@ -348,7 +348,7 @@ def get_attention_backend(
use_fused_attention = False

# Filter: Compute capability
global _flash_attn_3_plus, _use_flash_attn_3
global _use_flash_attn_3
if device_compute_capability < (8, 0):
if use_flash_attention:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
@@ -357,7 +357,7 @@
logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
use_fused_attention = False
if device_compute_capability < (9, 0):
if use_flash_attention and _flash_attn_3_plus:
if use_flash_attention and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
_use_flash_attn_3 = False

@@ -438,10 +438,9 @@ def get_attention_backend(
use_flash_attention = False

# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention:
if _flash_attn_3_plus and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for dropout")
_use_flash_attn_3 = False
if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for dropout")
_use_flash_attn_3 = False

# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
@@ -461,7 +460,7 @@
)
use_unfused_attention = False
if context_parallel and use_flash_attention:
if _flash_attn_3_plus and _use_flash_attn_3:
if _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for context parallelism")
_use_flash_attn_3 = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
@@ -556,7 +555,7 @@ def get_attention_backend(
use_fused_attention = False
if (
use_flash_attention
and _flash_attn_3_plus
and _use_flash_attn_3
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
@@ -590,6 +589,15 @@ def get_attention_backend(
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if (
use_flash_attention
and _use_flash_attn_3
and fp8
and fp8_meta["recipe"].fp8_dpa
and "padding" in attn_mask_type
):
logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
_use_flash_attn_3 = False

# Filter: Sliding window attention
# backend | window_size | diagonal alignment
@@ -633,15 +641,6 @@ def get_attention_backend(
attn_mask_type,
)
use_fused_attention = False
if (
use_flash_attention
and (window_size[0] != -1 or window_size[1] not in [-1, 0])
and _flash_attn_3_plus
):
logger.debug(
"Disabling FlashAttention 3 as it does not support sliding window attention"
)
_use_flash_attn_3 = False
if (
use_flash_attention
and (window_size[0] != -1 or window_size[1] not in [-1, 0])
@@ -662,12 +661,12 @@
# UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
# | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias
if use_flash_attention and core_attention_bias_type == "alibi":
if _flash_attn_3_plus and _use_flash_attn_3:
if _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for ALiBi")
_use_flash_attn_3 = False
if not _flash_attn_2_4_plus:
logger.debug("Disabling FlashAttention for ALiBi")
use_flash_attention = False
if not _use_flash_attn_3 and not _flash_attn_2_4_plus:
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention = False

if use_flash_attention and (
core_attention_bias_type not in ["no_bias", "alibi"]
@@ -827,19 +826,15 @@ def get_attention_backend(
"for performance reasons"
)
use_flash_attention = False

# Select FusedAttention for FP8
# FA3 uses default scaling factors (i.e. 1) in FP8 execution, while FusedAttention takes
# scaling factors from `fp8_meta` and offers more accurate quantization/de-quantization
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["FP8"]
and _use_flash_attn_3
):
logger.debug(
"Disabling FlashAttention 3 to give FusedAttention preference as FusedAttention "
"supports more accurate scaling factors in FP8 execution"
"Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
"in FP8 execution"
)
use_flash_attention = False

@@ -4963,6 +4958,10 @@ def __init__(
self.attention_type = attention_type
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.logger = logging.getLogger("FlashAttention")
self.logger.setLevel(_log_level)
if not self.logger.hasHandlers():
self.logger.addHandler(_stream_handler)

def forward(
self,
@@ -5033,6 +5032,10 @@ def forward(
x.transpose(0, 1)
for x in (query_layer._data, key_layer._data, value_layer._data)
]
query_layer, key_layer, value_layer = [
Float8Tensor.make_like(x, data=x._data)
for x in (query_layer, key_layer, value_layer)
]
if context_parallel:
query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
@@ -5168,33 +5171,62 @@ def forward(
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if _use_flash_attn_3:
fa_3_optional_forward_kwargs = {}
fa_3_optional_forward_kwargs["window_size"] = window_size
fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
activation_dtype = query_layer.dtype
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)

def convert_to_torch_float8(tensor, dtype):
out = torch.Tensor().to(device=tensor.device, dtype=dtype)
out.set_(
tensor._data.untyped_storage(),
tensor._data.storage_offset(),
tensor._data.shape,
tensor._data.stride(),
)
return out

if fp8_meta["recipe"].fp8_mha:
assert all(
isinstance(x, Float8Tensor)
for x in [query_layer, key_layer, value_layer]
), "q/k/v must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
query_layer, key_layer, value_layer = (
x.to(activation_dtype).to(torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
else:
query_layer, key_layer, value_layer = (
x.to(torch_dtype) for x in [query_layer, key_layer, value_layer]
Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward)
for x in [query_layer, key_layer, value_layer]
)
output, _ = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
deterministic=self.deterministic,
)
fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv
fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv
fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv
query_layer, key_layer, value_layer = (
convert_to_torch_float8(x, torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
try:
output, _ = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_3_optional_forward_kwargs,
)
except TypeError as e:
if _flash_attn_3_0_0_beta:
e.args = (
e.args[0]
+ ". Please update your FlashAttention 3 (beta) installation as it "
+ "may have added more supported arguments to its API. \n"
+ _flash_attn_3_installation_steps,
) + e.args[1:]
raise

if fp8 and fp8_meta["recipe"].fp8_mha:
output = cast_to_fp8(
output,
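
For context, a minimal self-contained sketch (independent of Transformer Engine, assuming PyTorch 2.1+ for the float8 dtypes) of the storage-reinterpretation trick that convert_to_torch_float8 above relies on: the raw uint8 buffer is viewed as a torch float8 tensor without copying.

import torch

raw = torch.randint(0, 256, (4, 8), dtype=torch.uint8)  # stand-in for Float8Tensor._data
f8 = torch.Tensor().to(device=raw.device, dtype=torch.float8_e4m3fn)
# Point the float8 tensor at the existing storage: same bytes, new dtype interpretation.
f8.set_(raw.untyped_storage(), raw.storage_offset(), raw.shape, raw.stride())
assert f8.data_ptr() == raw.data_ptr()
assert f8.shape == raw.shape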
@@ -5228,8 +5260,12 @@ def forward(
if qkv_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha:
output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
output = Float8Tensor.make_like(
output,
data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
.contiguous(),
)
else:
output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
elif qkv_format == "bshd":
