Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 committed Sep 27, 2024
1 parent d54cb00 commit b1db2ed
Showing 5 changed files with 33 additions and 7 deletions.
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/jit.py
@@ -109,7 +109,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
 def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
     """Disable native AMP for bias_gelu_fused_"""
     with torch.cuda.amp.autocast(enabled=False):
-        if bias.numel() != 0:
+        if bias is not None and bias.numel() != 0:
             return bias_gelu_fused_(inp, bias)
         return gelu_fused_(inp)

@@ -119,7 +119,7 @@ def bgrad_dgelu_fused(
 ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
     """Disable native AMP for `bgrad_dgelu_fused_`"""
     with torch.cuda.amp.autocast(enabled=False):
-        if bias.numel() != 0:
+        if bias is not None and bias.numel() != 0:
             return bgrad_dgelu_fused_(grad_output, inp, bias)
         return None, dgelu_fused_(grad_output, inp)
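For reference, a minimal sketch (not Transformer Engine code) of why the extra `bias is not None` guard matters: when a layer is built without a bias, the tensor passed here can be `None`, and calling `.numel()` on `None` raises `AttributeError`. With the guard, the call falls through to the bias-free path. `apply_bias_gelu` below is a hypothetical stand-in for `bias_gelu_fused`.

```python
import torch
import torch.nn.functional as F

def apply_bias_gelu(inp: torch.Tensor, bias) -> torch.Tensor:
    # Check for a missing bias before touching .numel().
    if bias is not None and bias.numel() != 0:
        return F.gelu(inp + bias)
    return F.gelu(inp)

x = torch.randn(4, 8)
assert torch.equal(apply_bias_gelu(x, None), F.gelu(x))  # falls back to plain GELU
```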
15 changes: 14 additions & 1 deletion transformer_engine/pytorch/module/layernorm.py
@@ -187,8 +187,21 @@ def reset_parameters(self, defer_init=False) -> None:
     @no_torch_dynamo()
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """LayerNorm FWD"""
+
         # Set the activation type for AMP.
-        TransformerEngineBaseModule.set_activation_dtype(self, inp)
+        # Note: This will soon be deprecated with
+        # https://github.com/NVIDIA/TransformerEngine/pull/1033
+        if torch.is_autocast_enabled():
+            self.activation_dtype = torch.get_autocast_gpu_dtype()
+        elif self.activation_dtype != inp.dtype:
+            dtype = inp.dtype
+            for name, param in self.named_parameters():
+                if param is not None:
+                    assert dtype == param.dtype, (
+                        "Data types for parameters must match when outside of autocasted region. "
+                        f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
+                    )
+            self.activation_dtype = dtype

         if torch.is_grad_enabled():
             fwd_fn = _LayerNorm.apply
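The inlined block above mirrors the dtype rule that the removed `TransformerEngineBaseModule.set_activation_dtype` call enforced: under autocast, compute in the autocast dtype; outside autocast, require the input and every parameter to share a dtype. A hedged sketch of that rule on a plain `torch.nn.Module` (`ToyNorm` is illustrative, not part of Transformer Engine):

```python
import torch

class ToyNorm(torch.nn.Module):
    """Minimal module that mirrors the activation-dtype selection shown above."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.activation_dtype = None

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        if torch.is_autocast_enabled():
            # Inside torch.autocast, use the autocast compute dtype.
            self.activation_dtype = torch.get_autocast_gpu_dtype()
        elif self.activation_dtype != inp.dtype:
            # Outside autocast, the input dtype must match every parameter dtype.
            for name, param in self.named_parameters():
                assert param.dtype == inp.dtype, (
                    f"dtype mismatch: input {inp.dtype}, parameter {name!r} {param.dtype}"
                )
            self.activation_dtype = inp.dtype
        return inp * self.weight

norm = ToyNorm(8)
norm(torch.randn(2, 8))  # float32 input and float32 parameters: passes the check
```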
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
@@ -435,7 +435,8 @@ def forward(
             ln_weight.weight_offloading = True
             fc1_weight.weight_offloading = True
             fc2_weight.weight_offloading = True
-            fc1_bias.weight_offloading = True
+            if fc1_bias is not None:
+                fc1_bias.weight_offloading = True

             inputmat.activation_offloading = True
             if normalization == "LayerNorm":
14 changes: 13 additions & 1 deletion transformer_engine/pytorch/module/rmsnorm.py
@@ -186,7 +186,19 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """RMSNorm FWD"""

         # Set the activation type for AMP.
-        TransformerEngineBaseModule.set_activation_dtype(self, inp)
+        # Note: This will soon be deprecated with
+        # https://github.com/NVIDIA/TransformerEngine/pull/1033
+        if torch.is_autocast_enabled():
+            self.activation_dtype = torch.get_autocast_gpu_dtype()
+        elif self.activation_dtype != inp.dtype:
+            dtype = inp.dtype
+            for name, param in self.named_parameters():
+                if param is not None:
+                    assert dtype == param.dtype, (
+                        "Data types for parameters must match when outside of autocasted region. "
+                        f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
+                    )
+            self.activation_dtype = dtype

         if torch.is_grad_enabled():
             fwd_fn = _RMSNorm.apply
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/transformer.py
@@ -751,7 +751,7 @@ def forward(
         return output

     def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
-        if drop_path is None and bias.numel() != 0:
+        if drop_path is None and bias is not None and bias.numel() != 0:
             if self.bias_dropout_fusion:
                 if self.training:
                     bias_dropout_add_func = bias_dropout_add_fused_train

@@ -763,7 +763,7 @@ def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
             with self.bias_dropout_add_exec_handler():
                 output = bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout)
         else:
-            if bias.numel() != 0:
+            if bias is not None and bias.numel() != 0:
                 hidden_state = hidden_state + bias
             out = torch.nn.functional.dropout(
                 hidden_state, p=self.hidden_dropout, training=self.training
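For context, the unfused fallback exercised in the second hunk composes bias add, dropout, and residual add in sequence, skipping the bias when it is absent. A hedged sketch of that composition (hypothetical helper, not the module's API; the residual add is inferred from the surrounding code, which is truncated here):

```python
import torch
import torch.nn.functional as F

def bias_dropout_add_unfused(hidden, bias, residual, p, training):
    # Skip the bias add entirely when the layer was built without a bias.
    if bias is not None and bias.numel() != 0:
        hidden = hidden + bias
    out = F.dropout(hidden, p=p, training=training)
    return residual + out

h, r = torch.randn(2, 4, 16), torch.randn(2, 4, 16)
out = bias_dropout_add_unfused(h, None, r, p=0.1, training=False)
```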
