From bcdc4d14861d91befa28ba94823aa274ac7cf9e4 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:23:41 -0700 Subject: [PATCH 01/22] add qkv descales to FA3 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f8ba46b2ea..1d91862f4e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5027,24 +5027,38 @@ def forward( fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) if _use_flash_attn_3: + fa_optional_forward_kwargs_fp8 = {} if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) activation_dtype = query_layer.dtype torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) + def convert_to_torch_float8(tensor, dtype): + out = torch.Tensor().to(device=tensor.device, dtype=dtype) + out.set_( + tensor._data.untyped_storage(), + tensor._data.storage_offset(), + tensor._data.shape, + tensor._data.stride() + ) + return out if fp8_meta["recipe"].fp8_mha: assert all( isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ), "q/k/v must be Float8Tensors for FP8 MHA." fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv - query_layer, key_layer, value_layer = ( - x.to(activation_dtype).to(torch_dtype) - for x in [query_layer, key_layer, value_layer] - ) else: query_layer, key_layer, value_layer = ( - x.to(torch_dtype) for x in [query_layer, key_layer, value_layer] + Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward) + for x in [query_layer, key_layer, value_layer] ) + fa_optional_forward_kwargs_fp8["descale_q"] = query_layer._scale_inv + fa_optional_forward_kwargs_fp8["descale_k"] = key_layer._scale_inv + fa_optional_forward_kwargs_fp8["descale_v"] = value_layer._scale_inv + query_layer, key_layer, value_layer = ( + convert_to_torch_float8(x, torch_dtype) + for x in [query_layer, key_layer, value_layer] + ) output, _ = func( query_layer, key_layer, @@ -5053,6 +5067,7 @@ def forward( softmax_scale=self.softmax_scale, causal="causal" in attn_mask_type, deterministic=self.deterministic, + **fa_optional_forward_kwargs_fp8, ) if fp8 and fp8_meta["recipe"].fp8_mha: output = cast_to_fp8( From 3ed49d088be79f8eeb65f073df50bb936d9e53e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Sep 2024 00:28:48 +0000 Subject: [PATCH 02/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 1d91862f4e..842d8d19be 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5032,15 +5032,17 @@ def forward( fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) activation_dtype = query_layer.dtype torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) + def convert_to_torch_float8(tensor, dtype): out = torch.Tensor().to(device=tensor.device, dtype=dtype) out.set_( tensor._data.untyped_storage(), tensor._data.storage_offset(), 
tensor._data.shape, - tensor._data.stride() + tensor._data.stride(), ) return out + if fp8_meta["recipe"].fp8_mha: assert all( isinstance(x, Float8Tensor) From 1db61e2f0d46a52ea70c59bc66a0219a6eccdc33 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Sep 2024 13:51:42 -0700 Subject: [PATCH 03/22] fix sbhd shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 842d8d19be..e138f946e8 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4892,6 +4892,10 @@ def forward( x.transpose(0, 1).contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] + query_layer, key_layer, value_layer = [ + Float8Tensor.make_like(x, data=x._data) + for x in (query_layer, key_layer, value_layer) + ] elif qkv_format in ["bshd", "thd"]: query_layer._data, key_layer._data, value_layer._data = [ x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) @@ -5104,8 +5108,12 @@ def convert_to_torch_float8(tensor, dtype): if qkv_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) if fp8 and fp8_meta["recipe"].fp8_mha: - output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d() - output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) + output = Float8Tensor.make_like( + output, + data=output._data.reshape( + batch_size, max_seqlen_q // cp_size, -1 + ).transpose(0,1).contiguous() + ) else: output = ( output.view(batch_size, max_seqlen_q // cp_size, -1) From 6a8666054d293e980fecc6a14faad09964981e51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Sep 2024 20:52:22 +0000 Subject: [PATCH 04/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index e138f946e8..296b68af8a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5110,10 +5110,10 @@ def convert_to_torch_float8(tensor, dtype): if fp8 and fp8_meta["recipe"].fp8_mha: output = Float8Tensor.make_like( output, - data=output._data.reshape( - batch_size, max_seqlen_q // cp_size, -1 - ).transpose(0,1).contiguous() - ) + data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) + .transpose(0, 1) + .contiguous(), + ) else: output = ( output.view(batch_size, max_seqlen_q // cp_size, -1) From 19e7f877026a19a32d2f02c6c9de20df4ae2e064 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Sep 2024 16:37:06 -0700 Subject: [PATCH 05/22] force the same dtype when comparing FA3 and cuDNN FP8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 2 ++ transformer_engine/pytorch/attention.py | 40 +++++++++++++++------ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index d110dece53..14456010b4 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py 
@@ -1319,6 +1319,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) try: + if a.dtype != b.dtype: + a = a.to(b.dtype) torch.testing.assert_close(a, b, atol=atol, rtol=rtol) except Exception as e: logging.debug(e) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 817f4bb62e..4dd70ade4e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4828,6 +4828,10 @@ def __init__( self.attention_type = attention_type self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic + self.logger = logging.getLogger("FlashAttention") + self.logger.setLevel(_log_level) + if not self.logger.hasHandlers(): + self.logger.addHandler(_stream_handler) def forward( self, @@ -5067,16 +5071,32 @@ def convert_to_torch_float8(tensor, dtype): convert_to_torch_float8(x, torch_dtype) for x in [query_layer, key_layer, value_layer] ) - output, _ = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - deterministic=self.deterministic, - **fa_optional_forward_kwargs_fp8, - ) + try: + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + deterministic=self.deterministic, + **fa_optional_forward_kwargs_fp8, + ) + except TypeError: + self.logger.debug( + "Running with default q, k, v descales, i.e. 1s. To enable custom " + "descales, please install flashattn-hopper (FA3) with this PR: " + "https://github.com/Dao-AILab/flash-attention/pull/1210." + ) + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + deterministic=self.deterministic, + ) if fp8 and fp8_meta["recipe"].fp8_mha: output = cast_to_fp8( output, From bff80b683664909bfdf75906f32b16f07960d171 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Sep 2024 16:40:04 -0700 Subject: [PATCH 06/22] Revert "force the same dtype when comparing FA3 and cuDNN FP8" This reverts commit 19e7f877026a19a32d2f02c6c9de20df4ae2e064. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 2 -- transformer_engine/pytorch/attention.py | 40 ++++++--------------- 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 14456010b4..d110dece53 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1319,8 +1319,6 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) try: - if a.dtype != b.dtype: - a = a.to(b.dtype) torch.testing.assert_close(a, b, atol=atol, rtol=rtol) except Exception as e: logging.debug(e) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 4dd70ade4e..817f4bb62e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4828,10 +4828,6 @@ def __init__( self.attention_type = attention_type self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic - self.logger = logging.getLogger("FlashAttention") - self.logger.setLevel(_log_level) - if not self.logger.hasHandlers(): - self.logger.addHandler(_stream_handler) def forward( self, @@ -5071,32 +5067,16 @@ def convert_to_torch_float8(tensor, dtype): convert_to_torch_float8(x, torch_dtype) for x in [query_layer, key_layer, value_layer] ) - try: - output, _ = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - deterministic=self.deterministic, - **fa_optional_forward_kwargs_fp8, - ) - except TypeError: - self.logger.debug( - "Running with default q, k, v descales, i.e. 1s. To enable custom " - "descales, please install flashattn-hopper (FA3) with this PR: " - "https://github.com/Dao-AILab/flash-attention/pull/1210." 
- ) - output, _ = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - deterministic=self.deterministic, - ) + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + deterministic=self.deterministic, + **fa_optional_forward_kwargs_fp8, + ) if fp8 and fp8_meta["recipe"].fp8_mha: output = cast_to_fp8( output, From 68b9b487e912c4837318f256e697d0c5f07b461a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Sep 2024 16:42:28 -0700 Subject: [PATCH 07/22] force the same dtype when comparing FA3 and cuDNN FP8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index d110dece53..14456010b4 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1319,6 +1319,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) try: + if a.dtype != b.dtype: + a = a.to(b.dtype) torch.testing.assert_close(a, b, atol=atol, rtol=rtol) except Exception as e: logging.debug(e) From 0553a8317f538129ee31086824336dd7f5924b16 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Sep 2024 16:43:31 -0700 Subject: [PATCH 08/22] add try/except for FA3 when custom qkv descales are not supported Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 40 ++++++++++++++++++------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 817f4bb62e..4dd70ade4e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4828,6 +4828,10 @@ def __init__( self.attention_type = attention_type self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic + self.logger = logging.getLogger("FlashAttention") + self.logger.setLevel(_log_level) + if not self.logger.hasHandlers(): + self.logger.addHandler(_stream_handler) def forward( self, @@ -5067,16 +5071,32 @@ def convert_to_torch_float8(tensor, dtype): convert_to_torch_float8(x, torch_dtype) for x in [query_layer, key_layer, value_layer] ) - output, _ = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - deterministic=self.deterministic, - **fa_optional_forward_kwargs_fp8, - ) + try: + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + deterministic=self.deterministic, + **fa_optional_forward_kwargs_fp8, + ) + except TypeError: + self.logger.debug( + "Running with default q, k, v descales, i.e. 1s. To enable custom " + "descales, please install flashattn-hopper (FA3) with this PR: " + "https://github.com/Dao-AILab/flash-attention/pull/1210." 
+ ) + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + deterministic=self.deterministic, + ) if fp8 and fp8_meta["recipe"].fp8_mha: output = cast_to_fp8( output, From b73760b50df7f57f56644063765df123a4a1baa8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Sep 2024 17:24:47 -0700 Subject: [PATCH 09/22] replace FA3 installation warning with a debug logging message Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 30 +++++++++++++------------ 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 4dd70ade4e..a204309164 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -85,6 +85,16 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo from transformer_engine.pytorch.graph import is_graph_capturing +# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 +_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) +# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 +_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) +_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL +_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} +_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2] +_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") +_stream_handler = logging.StreamHandler() +_stream_handler.setFormatter(_formatter) _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) @@ -105,8 +115,12 @@ _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1") except PackageNotFoundError: if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN: - warnings.warn( - "To use flash-attn v3, please use the following commands to install: \n" + logger = logging.getLogger() + logger.setLevel(_log_level) + if not logger.hasHandlers(): + logger.addHandler(_stream_handler) + logger.debug( + "To use flash-attn v3, please follow these steps to install the flashattn-hopper package: \n" """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" """(3) mkdir -p $python_path/flashattn_hopper \n""" @@ -132,18 +146,6 @@ from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd - -# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 -_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) -# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 -_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2] -_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") -_stream_handler = logging.StreamHandler() -_stream_handler.setFormatter(_formatter) - _attention_backends = { "attention_params": None, "use_flash_attention": None, From 66cc6f20e26b5245bd20d51bac96917d6eb19d44 Mon Sep 17 
00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Sep 2024 00:25:35 +0000 Subject: [PATCH 10/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index a204309164..b3648f91af 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -120,7 +120,8 @@ if not logger.hasHandlers(): logger.addHandler(_stream_handler) logger.debug( - "To use flash-attn v3, please follow these steps to install the flashattn-hopper package: \n" + "To use flash-attn v3, please follow these steps to install the flashattn-hopper" + " package: \n" """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" """(3) mkdir -p $python_path/flashattn_hopper \n""" From 32696859cffa97ca9709ff449449e24aa2f8f238 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 18 Sep 2024 17:35:32 -0700 Subject: [PATCH 11/22] fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b3648f91af..d02a4f4207 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -115,13 +115,12 @@ _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1") except PackageNotFoundError: if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN: - logger = logging.getLogger() - logger.setLevel(_log_level) - if not logger.hasHandlers(): - logger.addHandler(_stream_handler) - logger.debug( - "To use flash-attn v3, please follow these steps to install the flashattn-hopper" - " package: \n" + fa3_logger = logging.getLogger() + fa3_logger.setLevel(_log_level) + if not fa3_logger.hasHandlers(): + fa3_logger.addHandler(_stream_handler) + fa3_logger.debug( + "To use flash-attn v3, please follow these steps to install the flashattn-hopper package: \n" """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" """(3) mkdir -p $python_path/flashattn_hopper \n""" From 39a4e1d4f0e49b2ec4e7ea7abc3652fbdd09f0ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Sep 2024 00:38:39 +0000 Subject: [PATCH 12/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d02a4f4207..876c549c49 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -120,7 +120,8 @@ if not fa3_logger.hasHandlers(): fa3_logger.addHandler(_stream_handler) fa3_logger.debug( - "To use flash-attn v3, please follow these steps to install the flashattn-hopper package: \n" + "To use 
flash-attn v3, please follow these steps to install the flashattn-hopper" + " package: \n" """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" """(3) mkdir -p $python_path/flashattn_hopper \n""" From 336a452a4eb6e72cf79718d444acf69f0b039e25 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 3 Oct 2024 13:34:22 -0700 Subject: [PATCH 13/22] remove unused imports Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index dd746a1f2d..9441aa8452 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -132,12 +132,6 @@ from flashattn_hopper.flash_attn_interface import ( flash_attn_varlen_func as flash_attn_varlen_func_v3, ) - from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import - _flash_attn_forward as _flash_attn_forward_v3, - ) - from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import - _flash_attn_backward as _flash_attn_backward_v3, - ) _use_flash_attn_3 = True From 2e140c522598088e6da6fa618a3d6ca4b692c6e3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 3 Oct 2024 13:38:21 -0700 Subject: [PATCH 14/22] avoid varlen_func for FP8 and improve messaging Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 65 ++++++++++++++----------- 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9441aa8452..4fab75c6ef 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -110,9 +110,15 @@ _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") _flash_attn_3_plus = False _use_flash_attn_3 = False +_flash_attn_3_installation_steps = """\ +(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" +(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` +(3) mkdir -p $python_path/flashattn_hopper +(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" try: _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper")) - _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1") + _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.9") + _flash_attn_3_0_0_beta = _flash_attn_3_plus and _flash_attn_v3_version < PkgVersion("3.0.0") except PackageNotFoundError: if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN: fa3_logger = logging.getLogger() @@ -120,12 +126,8 @@ if not fa3_logger.hasHandlers(): fa3_logger.addHandler(_stream_handler) fa3_logger.debug( - "To use flash-attn v3, please follow these steps to install the flashattn-hopper" - " package: \n" - """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" - """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" - """(3) mkdir -p $python_path/flashattn_hopper \n""" - """(4) wget -P 
$python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" + "To use flash-attn v3, please follow these steps to install the flashattn-hopper " + "package: \n" + _flash_attn_3_installation_steps ) else: from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 @@ -433,6 +435,16 @@ def get_attention_backend( "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False + if ( + use_flash_attention + and _use_flash_attn_3 + and fp8 + and fp8_meta["recipe"].fp8_dpa + ): + logger.debug( + "Disabling FlashAttention 3 for FP8 and qkv_format = thd" + ) + _use_flash_attn_3 = False # Filter: Dropout if attention_dropout != 0.0 and use_flash_attention: @@ -5018,7 +5030,7 @@ def forward( if _flash_attn_2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic fa_optional_forward_args_thd = [] - if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: + if (qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type) or fp8: func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 else: if _flash_attn_2_5_7_plus: @@ -5033,7 +5045,8 @@ def forward( fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) if _use_flash_attn_3: - fa_optional_forward_kwargs_fp8 = {} + fa_3_optional_forward_kwargs = {} + fa_3_optional_forward_kwargs["deterministic"] = self.deterministic if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) activation_dtype = query_layer.dtype @@ -5060,9 +5073,9 @@ def convert_to_torch_float8(tensor, dtype): Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward) for x in [query_layer, key_layer, value_layer] ) - fa_optional_forward_kwargs_fp8["descale_q"] = query_layer._scale_inv - fa_optional_forward_kwargs_fp8["descale_k"] = key_layer._scale_inv - fa_optional_forward_kwargs_fp8["descale_v"] = value_layer._scale_inv + fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv + fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv + fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv query_layer, key_layer, value_layer = ( convert_to_torch_float8(x, torch_dtype) for x in [query_layer, key_layer, value_layer] @@ -5075,24 +5088,18 @@ def convert_to_torch_float8(tensor, dtype): *fa_optional_forward_args_thd, softmax_scale=self.softmax_scale, causal="causal" in attn_mask_type, - deterministic=self.deterministic, - **fa_optional_forward_kwargs_fp8, - ) - except TypeError: - self.logger.debug( - "Running with default q, k, v descales, i.e. 1s. To enable custom " - "descales, please install flashattn-hopper (FA3) with this PR: " - "https://github.com/Dao-AILab/flash-attention/pull/1210." - ) - output, _ = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - deterministic=self.deterministic, + **fa_3_optional_forward_kwargs, ) + except TypeError as e: + if _flash_attn_3_0_0_beta: + e.args = ( + e.args[0] + + ". Please update your FlashAttention 3 (beta) installation as it " + + "may have added more supported arguments to its API. 
\n" + + _flash_attn_3_installation_steps, + ) + e.args[1:] + raise + if fp8 and fp8_meta["recipe"].fp8_mha: output = cast_to_fp8( output, From 1138edf4b71c186699bc9612eca8168c9e0d5751 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Oct 2024 20:39:12 +0000 Subject: [PATCH 15/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 4fab75c6ef..2edbf0b405 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -127,7 +127,8 @@ fa3_logger.addHandler(_stream_handler) fa3_logger.debug( "To use flash-attn v3, please follow these steps to install the flashattn-hopper " - "package: \n" + _flash_attn_3_installation_steps + "package: \n" + + _flash_attn_3_installation_steps ) else: from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 @@ -435,15 +436,8 @@ def get_attention_backend( "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False - if ( - use_flash_attention - and _use_flash_attn_3 - and fp8 - and fp8_meta["recipe"].fp8_dpa - ): - logger.debug( - "Disabling FlashAttention 3 for FP8 and qkv_format = thd" - ) + if use_flash_attention and _use_flash_attn_3 and fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug("Disabling FlashAttention 3 for FP8 and qkv_format = thd") _use_flash_attn_3 = False # Filter: Dropout @@ -5097,7 +5091,7 @@ def convert_to_torch_float8(tensor, dtype): + ". Please update your FlashAttention 3 (beta) installation as it " + "may have added more supported arguments to its API. 
\n" + _flash_attn_3_installation_steps, - ) + e.args[1:] + ) + e.args[1:] raise if fp8 and fp8_meta["recipe"].fp8_mha: From 4095be8c4bac03ed8ca6a4e6454b29504cda55ac Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 3 Oct 2024 13:40:23 -0700 Subject: [PATCH 16/22] add SWA support for FA3 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 2edbf0b405..6aa68f55ca 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -636,15 +636,6 @@ def get_attention_backend( attn_mask_type, ) use_fused_attention = False - if ( - use_flash_attention - and (window_size[0] != -1 or window_size[1] not in [-1, 0]) - and _flash_attn_3_plus - ): - logger.debug( - "Disabling FlashAttention 3 as it does not support sliding window attention" - ) - _use_flash_attn_3 = False if ( use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]) @@ -5040,6 +5031,7 @@ def forward( fa_optional_forward_args_thd.append(max_seqlen_kv) if _use_flash_attn_3: fa_3_optional_forward_kwargs = {} + fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["deterministic"] = self.deterministic if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) From 8e2bcc25f79d84cb9090c9561d130df14d39017d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 3 Oct 2024 13:45:09 -0700 Subject: [PATCH 17/22] fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6aa68f55ca..c307302d55 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -127,8 +127,8 @@ fa3_logger.addHandler(_stream_handler) fa3_logger.debug( "To use flash-attn v3, please follow these steps to install the flashattn-hopper " - "package: \n" - + _flash_attn_3_installation_steps + "package: \n%s", + _flash_attn_3_installation_steps ) else: from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 From 7bf4936edd0852f9e85caa846faa370ab76b930e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Oct 2024 20:46:11 +0000 Subject: [PATCH 18/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index c307302d55..9a6e27ec2b 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -128,7 +128,7 @@ fa3_logger.debug( "To use flash-attn v3, please follow these steps to install the flashattn-hopper " "package: \n%s", - _flash_attn_3_installation_steps + _flash_attn_3_installation_steps, ) else: from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 From b765f3d5c042b7adc79fdff4a29a66fbd20c68c1 Mon Sep 17 00:00:00 2001 From: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> Date: Sun, 6 Oct 2024 11:56:46 -0700 Subject: [PATCH 19/22] change preference reason for FP8 logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9a6e27ec2b..9b9a6a860b 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -821,10 +821,6 @@ def get_attention_backend( "for performance reasons" ) use_flash_attention = False - - # Select FusedAttention for FP8 - # FA3 uses default scaling factors (i.e. 1) in FP8 execution, while FusedAttention takes - # scaling factors from `fp8_meta` and offers more accurate quantization/de-quantization if ( use_flash_attention and use_fused_attention @@ -832,8 +828,8 @@ def get_attention_backend( and _use_flash_attn_3 ): logger.debug( - "Disabling FlashAttention 3 to give FusedAttention preference as FusedAttention " - "supports more accurate scaling factors in FP8 execution" + "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons " + "in FP8 execution" ) use_flash_attention = False From a4030e85a62f99c69d9e375c9245eacf79289929 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 7 Oct 2024 16:39:58 -0700 Subject: [PATCH 20/22] minor fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 31 ++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9b9a6a860b..763a8b6c5d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -348,7 +348,7 @@ def get_attention_backend( use_fused_attention = False # Filter: Compute capability - global _flash_attn_3_plus, _use_flash_attn_3 + global _use_flash_attn_3 if device_compute_capability < (8, 0): if use_flash_attention: logger.debug("Disabling FlashAttention as it requires compute capability sm80+") @@ -357,7 +357,7 @@ def get_attention_backend( logger.debug("Disabling FusedAttention as it requires compute capability sm80+") use_fused_attention = False if device_compute_capability < (9, 0): - if use_flash_attention and _flash_attn_3_plus: + if use_flash_attention and _use_flash_attn_3: logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+") _use_flash_attn_3 = False @@ -436,15 +436,11 @@ def get_attention_backend( "padding between sequences, i.e. 
[a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False - if use_flash_attention and _use_flash_attn_3 and fp8 and fp8_meta["recipe"].fp8_dpa: - logger.debug("Disabling FlashAttention 3 for FP8 and qkv_format = thd") - _use_flash_attn_3 = False # Filter: Dropout - if attention_dropout != 0.0 and use_flash_attention: - if _flash_attn_3_plus and _use_flash_attn_3: - logger.debug("Disabling FlashAttention 3 for dropout") - _use_flash_attn_3 = False + if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3: + logger.debug("Disabling FlashAttention 3 for dropout") + _use_flash_attn_3 = False # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends @@ -464,7 +460,7 @@ def get_attention_backend( ) use_unfused_attention = False if context_parallel and use_flash_attention: - if _flash_attn_3_plus and _use_flash_attn_3: + if _use_flash_attn_3: logger.debug("Disabling FlashAttention 3 for context parallelism") _use_flash_attn_3 = False if fp8 and fp8_meta["recipe"].fp8_dpa: @@ -559,7 +555,7 @@ def get_attention_backend( use_fused_attention = False if ( use_flash_attention - and _flash_attn_3_plus + and _use_flash_attn_3 and attn_mask_type in ["causal", "padding_causal"] and max_seqlen_q != max_seqlen_kv ): @@ -593,6 +589,9 @@ def get_attention_backend( "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" ) use_flash_attention = False + if use_flash_attention and _use_flash_attn_3 and fp8 and fp8_meta["recipe"].fp8_dpa and "padding" in attn_mask_type: + logger.debug("Disabling FlashAttention 3 for FP8 and padding masks") + _use_flash_attn_3 = False # Filter: Sliding window attention # backend | window_size | diagonal alignment @@ -656,12 +655,12 @@ def get_attention_backend( # UnfusedDotProductAttention | no_bias, pre/post_scale_bias | # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias if use_flash_attention and core_attention_bias_type == "alibi": - if _flash_attn_3_plus and _use_flash_attn_3: + if _use_flash_attn_3: logger.debug("Disabling FlashAttention 3 for ALiBi") _use_flash_attn_3 = False - if not _flash_attn_2_4_plus: - logger.debug("Disabling FlashAttention for ALiBi") - use_flash_attention = False + elif not _flash_attn_2_4_plus: + logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") + use_flash_attention = False if use_flash_attention and ( core_attention_bias_type not in ["no_bias", "alibi"] @@ -5011,7 +5010,7 @@ def forward( if _flash_attn_2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic fa_optional_forward_args_thd = [] - if (qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type) or fp8: + if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 else: if _flash_attn_2_5_7_plus: From 569532a1c2b72ad7284673ab7e1536652151570d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 23:40:26 +0000 Subject: [PATCH 21/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 763a8b6c5d..2c3be0d503 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -589,7 +589,13 @@ def 
get_attention_backend( "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" ) use_flash_attention = False - if use_flash_attention and _use_flash_attn_3 and fp8 and fp8_meta["recipe"].fp8_dpa and "padding" in attn_mask_type: + if ( + use_flash_attention + and _use_flash_attn_3 + and fp8 + and fp8_meta["recipe"].fp8_dpa + and "padding" in attn_mask_type + ): logger.debug("Disabling FlashAttention 3 for FP8 and padding masks") _use_flash_attn_3 = False From f006a25d6415927f7f624571d4ee17bc8882ea3b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 7 Oct 2024 16:49:31 -0700 Subject: [PATCH 22/22] minor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 2 +- transformer_engine/pytorch/attention.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index b69aed6648..7b21b997cd 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -13,7 +13,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py +NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 67225fd900..7aa8f0def1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -664,7 +664,7 @@ def get_attention_backend( if _use_flash_attn_3: logger.debug("Disabling FlashAttention 3 for ALiBi") _use_flash_attn_3 = False - elif not _flash_attn_2_4_plus: + if not _use_flash_attn_3 and not _flash_attn_2_4_plus: logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") use_flash_attention = False
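
Note on the FP8 pattern used throughout this series: q/k/v are quantized to FP8, their per-tensor scale_inv factors are forwarded to FlashAttention 3 as descale_q/descale_k/descale_v, and the raw FP8 payload is handed to FA3 reinterpreted as a torch float8 dtype without a copy (the convert_to_torch_float8 helper from PATCH 01/22). The sketch below restates that pattern outside of Transformer Engine as a standalone illustration: quantize_to_fp8 and reinterpret_as_torch_float8 are hypothetical helper names, the toy quantization stands in for Float8Tensor.to_float8 and fp8_meta scaling, and PyTorch 2.1+ (float8 dtypes) is assumed.

import torch

def quantize_to_fp8(x, fp8_dtype=torch.float8_e4m3fn):
    # Toy per-tensor quantization; in the patches this role is played by
    # Float8Tensor.to_float8 driven by fp8_meta scales, not this helper.
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = torch.finfo(fp8_dtype).max / amax      # cast scale
    data = (x.float() * scale).to(fp8_dtype)       # FP8 payload
    scale_inv = (1.0 / scale).reshape(1)           # the "descale" FA3 receives
    return data, scale_inv

def reinterpret_as_torch_float8(byte_data, dtype):
    # Zero-copy view of a raw 1-byte payload as a torch float8 tensor, using the same
    # set_()/untyped_storage() trick as convert_to_torch_float8 in PATCH 01/22.
    out = torch.Tensor().to(device=byte_data.device, dtype=dtype)
    out.set_(
        byte_data.untyped_storage(),
        byte_data.storage_offset(),
        byte_data.shape,
        byte_data.stride(),
    )
    return out

q = torch.randn(2, 128, 16, 64, dtype=torch.bfloat16)
q_fp8, descale_q = quantize_to_fp8(q)
# In the patches the payload lives in Float8Tensor._data as a uint8 tensor; the
# uint8 round trip below only mimics that starting point.
q_view = reinterpret_as_torch_float8(q_fp8.view(torch.uint8), torch.float8_e4m3fn)
# descale_q and its k/v counterparts are what the series passes to FA3 as
# descale_q=/descale_k=/descale_v=. FA3 builds that predate
# https://github.com/Dao-AILab/flash-attention/pull/1210 reject these kwargs,
# which is why the series wraps the FA3 call in try/except and points users at
# the flashattn-hopper installation steps.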