
[PyTorch] Miscellaneous fixes for FA3 attention #1174

Merged
29 commits merged into main from add_descales on Oct 8, 2024
Changes shown below are from 18 of the 29 commits.

Commits (29)
bcdc4d1
add qkv descales to FA3
cyanguwa Sep 10, 2024
3ed49d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2024
1db61e2
fix sbhd shapes
cyanguwa Sep 17, 2024
6a86660
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2024
7da4b6c
Merge branch 'main' into add_descales
cyanguwa Sep 17, 2024
de3db0a
Merge branch 'main' into add_descales
cyanguwa Sep 18, 2024
19e7f87
force the same dtype when comparing FA3 and cuDNN FP8
cyanguwa Sep 18, 2024
bff80b6
Revert "force the same dtype when comparing FA3 and cuDNN FP8"
cyanguwa Sep 18, 2024
68b9b48
force the same dtype when comparing FA3 and cuDNN FP8
cyanguwa Sep 18, 2024
0553a83
add try/except for FA3 when custom qkv descales are not supported
cyanguwa Sep 18, 2024
b73760b
replace FA3 installation warning with a debug logging message
cyanguwa Sep 19, 2024
66cc6f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2024
3269685
fix lint
cyanguwa Sep 19, 2024
39a4e1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2024
5bcc355
Merge branch 'main' into add_descales
cyanguwa Sep 19, 2024
3dbee25
Merge branch 'main' into add_descales
cyanguwa Sep 27, 2024
c01a5b2
Merge branch 'NVIDIA:main' into add_descales
cyanguwa Oct 1, 2024
12dc8a9
Merge branch 'main' into add_descales
cyanguwa Oct 3, 2024
336a452
remove unused imports
cyanguwa Oct 3, 2024
2e140c5
avoid varlen_func for FP8 and improve messaging
cyanguwa Oct 3, 2024
1138edf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2024
4095be8
add SWA support for FA3
cyanguwa Oct 3, 2024
8e2bcc2
fix lint
cyanguwa Oct 3, 2024
7bf4936
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2024
b765f3d
change preference reason for FP8 logic
cyanguwa Oct 6, 2024
a4030e8
minor fixes
cyanguwa Oct 7, 2024
569532a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2024
e907ad7
Merge branch 'main' into add_descales
cyanguwa Oct 7, 2024
f006a25
minor fix
cyanguwa Oct 7, 2024
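
Several of the commits above touch the new debug-logging path (for example, "replace FA3 installation warning with a debug logging message"). As a hedged usage sketch, assuming a working Transformer Engine install, the NVTE_DEBUG and NVTE_DEBUG_LEVEL variables that appear in the attention.py diff below can surface those messages like this:

# Sketch only: enable verbose attention debug logs before importing Transformer Engine,
# since attention.py reads these environment variables at import time.
import os

os.environ["NVTE_DEBUG"] = "1"        # 0/1 disables/enables debug mode
os.environ["NVTE_DEBUG_LEVEL"] = "2"  # 0/1/2 maps to WARNING/INFO/DEBUG

import transformer_engine.pytorch  # noqa: E402  (import after setting the variables)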
2 changes: 2 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -1319,6 +1319,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
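As a standalone illustration of the test change above (the function name here is hypothetical, not part of the test suite), the helper casts one tensor to the other's dtype so that torch.testing.assert_close does not fail purely on a dtype mismatch when FA3 and cuDNN FP8 outputs are compared:

# Minimal sketch of the dtype-alignment pattern added to the _error helper.
import torch

def assert_close_across_dtypes(a, b, atol, rtol):
    # assert_close checks dtypes by default, so align them before comparing values.
    if a.dtype != b.dtype:
        a = a.to(b.dtype)
    torch.testing.assert_close(a, b, atol=atol, rtol=rtol)

assert_close_across_dtypes(
    torch.ones(4, dtype=torch.bfloat16), torch.ones(4, dtype=torch.float32), atol=1e-3, rtol=0.0
)
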
108 changes: 78 additions & 30 deletions transformer_engine/pytorch/attention.py
@@ -85,6 +85,16 @@
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
@@ -105,8 +115,13 @@
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
except PackageNotFoundError:
if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
warnings.warn(
"To use flash-attn v3, please use the following commands to install: \n"
fa3_logger = logging.getLogger()
fa3_logger.setLevel(_log_level)
if not fa3_logger.hasHandlers():
fa3_logger.addHandler(_stream_handler)
fa3_logger.debug(
"To use flash-attn v3, please follow these steps to install the flashattn-hopper"
" package: \n"
"""(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
"""(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
"""(3) mkdir -p $python_path/flashattn_hopper \n"""
@@ -132,18 +147,6 @@
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd


# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

_attention_backends = {
"attention_params": None,
"use_flash_attention": None,
@@ -4828,6 +4831,10 @@ def __init__(
self.attention_type = attention_type
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.logger = logging.getLogger("FlashAttention")
self.logger.setLevel(_log_level)
if not self.logger.hasHandlers():
self.logger.addHandler(_stream_handler)

def forward(
self,
@@ -4893,6 +4900,10 @@ def forward(
x.transpose(0, 1)
for x in (query_layer._data, key_layer._data, value_layer._data)
]
query_layer, key_layer, value_layer = [
Float8Tensor.make_like(x, data=x._data)
for x in (query_layer, key_layer, value_layer)
]
if context_parallel:
query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
@@ -5028,33 +5039,66 @@ def forward(
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if _use_flash_attn_3:
fa_optional_forward_kwargs_fp8 = {}
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
activation_dtype = query_layer.dtype
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)

def convert_to_torch_float8(tensor, dtype):
out = torch.Tensor().to(device=tensor.device, dtype=dtype)
out.set_(
tensor._data.untyped_storage(),
tensor._data.storage_offset(),
tensor._data.shape,
tensor._data.stride(),
)
return out

if fp8_meta["recipe"].fp8_mha:
assert all(
isinstance(x, Float8Tensor)
for x in [query_layer, key_layer, value_layer]
), "q/k/v must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
query_layer, key_layer, value_layer = (
x.to(activation_dtype).to(torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
else:
query_layer, key_layer, value_layer = (
x.to(torch_dtype) for x in [query_layer, key_layer, value_layer]
Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward)
for x in [query_layer, key_layer, value_layer]
)
output, _ = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
deterministic=self.deterministic,
)
fa_optional_forward_kwargs_fp8["descale_q"] = query_layer._scale_inv
fa_optional_forward_kwargs_fp8["descale_k"] = key_layer._scale_inv
fa_optional_forward_kwargs_fp8["descale_v"] = value_layer._scale_inv
query_layer, key_layer, value_layer = (
convert_to_torch_float8(x, torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
try:
output, _ = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
deterministic=self.deterministic,
**fa_optional_forward_kwargs_fp8,
)
except TypeError:
self.logger.debug(
"Running with default q, k, v descales, i.e. 1s. To enable custom "
"descales, please install flashattn-hopper (FA3) with this PR: "
"https://github.com/Dao-AILab/flash-attention/pull/1210."
)
output, _ = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
deterministic=self.deterministic,
)
if fp8 and fp8_meta["recipe"].fp8_mha:
output = cast_to_fp8(
output,
@@ -5088,8 +5132,12 @@ def forward(
if qkv_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha:
output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
output = Float8Tensor.make_like(
output,
data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
.contiguous(),
)
else:
output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
elif qkv_format == "bshd":
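
For reference, the convert_to_torch_float8 helper in the attention.py diff builds a zero-copy torch float8 view over the Float8Tensor's raw byte storage. The following is a self-contained sketch of that set_() trick, assuming a PyTorch build that provides torch.float8_e4m3fn; the Transformer Engine Float8Tensor pieces are omitted:

# Sketch of reinterpreting raw uint8 data as an FP8 tensor without copying.
import torch

raw = torch.randint(0, 256, (2, 4), dtype=torch.uint8)   # stand-in for Float8Tensor._data
fp8_view = torch.Tensor().to(device=raw.device, dtype=torch.float8_e4m3fn)
fp8_view.set_(raw.untyped_storage(), raw.storage_offset(), raw.shape, raw.stride())
assert fp8_view.dtype == torch.float8_e4m3fn and fp8_view.shape == raw.shape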