[PyTorch] Improve logging/messaging in attention #1074

Merged · 8 commits · Aug 6, 2024
60 changes: 22 additions & 38 deletions transformer_engine/pytorch/attention.py
@@ -98,12 +98,12 @@
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
@@ -262,6 +262,9 @@ def get_attention_backend(

# Run config
logger = logging.getLogger("DotProductAttention")
logger.setLevel(_log_level)
if not logger.hasHandlers():
logger.addHandler(_stream_handler)
device_compute_capability = get_device_compute_capability()
cudnn_version = get_cudnn_version()
run_config = {
@@ -3217,31 +3220,28 @@ def check_set_window_size(
"""
orig_window_size = window_size
if "causal" in attn_mask_type:
if orig_window_size is None or (
orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
):
if orig_window_size is None:
window_size = (-1, 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] >= 0:
elif orig_window_size == (-1, -1) or (
orig_window_size[0] >= 0 and orig_window_size[1] != 0
):
window_size = (orig_window_size[0], 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
else:
elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
assert False, (
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
if orig_window_size is None or (
orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
):
if orig_window_size is None:
window_size = (-1, -1)
elif orig_window_size == (-1, 0):
window_size = (-1, -1)
warnings.warn(
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] < 0 or orig_window_size[1] < 0:
elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
assert False, (
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
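For reviewers, the rewritten branches normalize window_size roughly as follows (a sketch of the intended mapping based on the hunk above; the call form check_set_window_size(attn_mask_type, window_size) is assumed):

    # causal-style masks: the right half-window is forced to 0
    # None      -> (-1, 0)
    # (-1, -1)  -> (-1, 0)   with a warning
    # (4, 2)    -> (4, 0)    with a warning
    # (4, 0)    -> (4, 0)    unchanged
    # (-2, 0)   -> AssertionError
    #
    # "no_mask" / "padding" / "arbitrary" masks: both halves must be -1 or >= 0
    # None      -> (-1, -1)
    # (-1, 0)   -> (-1, -1)  with a warning
    # (3, 5)    -> (3, 5)    unchanged
    # (-1, -2)  -> AssertionError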
@@ -3541,9 +3541,7 @@ def forward(
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc_qkvpacked")
if fp8:
logger.debug("Running forward in FP8")
if fp8_meta["recipe"].fp8_mha:
assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
@@ -3627,7 +3625,6 @@ def forward(
fp8_meta["scaling_fwd"].scale_inv.clone(),
)
else:
logger.debug("Running forward in %s", qkv.dtype)
out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training,
max_seqlen,
@@ -3680,7 +3677,6 @@ def forward(

@staticmethod
def backward(ctx, d_out):
logger = logging.getLogger("FusedAttnFunc_qkvpacked")
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(
d_out, Float8Tensor
@@ -3734,7 +3730,6 @@ def backward(ctx, d_out):
else:
with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
if ctx.fp8:
logger.debug("Running backward in FP8")
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -3800,7 +3795,6 @@ def backward(ctx, d_out):
ctx.qkv_dtype,
).view(dqkv_fp8.shape)
else:
logger.debug("Running backward in %s", qkv.dtype)
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(qkv.dtype)
dqkv, *rest = fused_attn_bwd_qkvpacked(
@@ -3918,9 +3912,7 @@ def forward(
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc_kvpacked")
if fp8:
logger.debug("Running forward in FP8")
if fp8_meta["recipe"].fp8_mha:
assert isinstance(q, Float8Tensor) and isinstance(
kv, Float8Tensor
@@ -4017,7 +4009,6 @@ def forward(
fp8_meta["scaling_fwd"].scale_inv.clone(),
)
else:
logger.debug("Running forward in %s", q.dtype)
out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training,
max_seqlen_q,
@@ -4081,7 +4072,6 @@ def forward(

@staticmethod
def backward(ctx, d_out):
logger = logging.getLogger("FusedAttnFunc_kvpacked")
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(
d_out, Float8Tensor
@@ -4139,7 +4129,6 @@ def backward(ctx, d_out):
else:
with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
if ctx.fp8:
logger.debug("Running backward in FP8")
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -4224,7 +4213,6 @@ def backward(ctx, d_out):
ctx.qkv_dtype,
).view(dkv_fp8.shape)
else:
logger.debug("Running backward in %s", q.dtype)
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(q.dtype)
dq, dkv, *rest = fused_attn_bwd_kvpacked(
@@ -4355,9 +4343,7 @@ def forward(
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc")
if fp8:
logger.debug("Running forward in FP8")
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
@@ -4525,7 +4511,6 @@ def forward(
fp8_meta["scaling_fwd"].scale_inv.clone(),
)
else:
logger.debug("Running forward in %s", q.dtype)
out_ret, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
@@ -4599,7 +4584,6 @@ def forward(

@staticmethod
def backward(ctx, d_out):
logger = logging.getLogger("FusedAttnFunc")
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(
d_out, Float8Tensor
@@ -4661,7 +4645,6 @@ def backward(ctx, d_out):
else:
with torch.cuda.nvtx.range("_FusedAttn"):
if ctx.fp8:
logger.debug("Running backward in FP8")
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -4799,7 +4782,6 @@ def backward(ctx, d_out):
ctx.qkv_dtype,
).view(dv_fp8.shape)
else:
logger.debug("Running backward in %s", q.dtype)
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(q.dtype)
dq, dk, dv, *rest = fused_attn_bwd(
@@ -4940,7 +4922,6 @@ def __init__(
) -> None:
super().__init__()

self.logger = logging.getLogger("FusedAttention")
self.softmax_scale = softmax_scale
self.attention_dropout = attention_dropout
self.attention_dropout_ctx = attention_dropout_ctx
@@ -5284,6 +5265,9 @@ def __init__(
super().__init__()

self.logger = logging.getLogger("DotProductAttention")
self.logger.setLevel(_log_level)
if not self.logger.hasHandlers():
self.logger.addHandler(_stream_handler)
self.qkv_format = qkv_format
attn_mask_type = attn_mask_type.replace(",", "_")
if attn_mask_type == "causal_padding":
@@ -5606,7 +5590,7 @@ def forward(
if self.fp8_meta["recipe"].fp8_mha:
if not self.fp8_meta["recipe"].fp8_dpa:
self.fp8_meta["recipe"].fp8_dpa = True
self.logger.WARNING(
self.logger.warning(
"""Forcing fp8_meta["recipe"].fp8_dpa=True due to """
"""fp8_meta["recipe"].fp8_mha=True"""
)
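To exercise the new logging while reviewing, the level is picked up from two environment variables at import time, so they have to be set before transformer_engine.pytorch is imported. A minimal sketch (the exact messages depend on the build and the backend that gets selected):

    import os

    # Read once at module import: level = NVTE_DEBUG * NVTE_DEBUG_LEVEL,
    # mapped as 0 -> WARNING, 1 -> INFO, 2 -> DEBUG.
    os.environ["NVTE_DEBUG"] = "1"
    os.environ["NVTE_DEBUG_LEVEL"] = "2"

    import transformer_engine.pytorch as te  # attention loggers now emit DEBUG output

    # Subsequent DotProductAttention calls should log backend-selection details
    # such as the run config and which attention backend was chosen.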
23 changes: 0 additions & 23 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -3,8 +3,6 @@
# See LICENSE for license information.

"""GroupedLinear API"""
import os
import logging
from typing import Union, Optional, Callable, Tuple, List, Dict, Any

import torch
@@ -45,17 +43,6 @@
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)

__all__ = ["GroupedLinear"]

"""
@@ -97,7 +84,6 @@ def forward(
is_grad_enabled: bool,
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor:
logger = logging.getLogger("GroupedLinear")
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
@@ -151,8 +137,6 @@ def forward(
inputmats = inputmats_no_fp8

if fp8:
logger.debug("Running forward in FP8")

bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases

@@ -184,8 +168,6 @@ def forward(
use_split_accumulator=_2X_ACC_FPROP,
)
else:
logger.debug("Running forward in %s", activation_dtype)

# Cast for native AMP
weights = [cast_if_needed(w, activation_dtype) for w in weights]
biases = (
@@ -286,8 +268,6 @@ def forward(

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
logger = logging.getLogger("GroupedLinear")

with torch.cuda.nvtx.range("_GroupedLinear_backward"):
(
fwd_scale_inverses,
@@ -353,7 +333,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:

if ctx.requires_dgrad:
if ctx.fp8:
logger.debug("Running backward in FP8")
dgrad = torch.empty(
(sum(ctx.m_splits), weights_fp8[i].size(1)),
dtype=ctx.activation_dtype,
@@ -376,8 +355,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
logger.debug("Running backward in %s", ctx.activation_dtype)

dgrad = torch.empty(
(sum(ctx.m_splits), weights[0].size(1)),
dtype=ctx.activation_dtype,
22 changes: 0 additions & 22 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -5,7 +5,6 @@
"""LayerNormLinear API"""
import os
import warnings
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
@@ -48,17 +47,6 @@
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)

__all__ = ["LayerNormLinear"]


@@ -104,7 +92,6 @@ def forward(
ub_name: str,
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
logger = logging.getLogger("LayerNormLinear")
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
@@ -203,8 +190,6 @@ def forward(
ln_out = ln_out_total

if fp8:
logger.debug("Running forward in FP8")

bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

@@ -259,8 +244,6 @@ def forward(
dtype=activation_dtype,
)
else:
logger.debug("Running forward in %s", activation_dtype)

# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -379,7 +362,6 @@ def forward(
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
logger = logging.getLogger("LayerNormLinear")
if isinstance(grad_outputs[0], Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[
0
@@ -500,8 +482,6 @@ def backward(
ub_obj = None

if ctx.fp8:
logger.debug("Running backward in FP8")

fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
out_index, meta_tensor, out_te_type, out_type = (
@@ -544,8 +524,6 @@ def backward(
)
clear_tensor_data(grad_output_c)
else:
logger.debug("Running backward in %s", ctx.activation_dtype)

# DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = tex.gemm(
weight,