diff --git a/megatron/training.py b/megatron/training.py index 6ba26f3944..467530837b 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -883,12 +883,11 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, if not skipped_iter: total_loss_dict[key] = total_loss_dict.get( key, get_accelerator().FloatTensor([0.0])) + loss_dict[key] - else: - value = loss_dict[key].float().sum().item() - is_nan = value == float('inf') or \ - value == -float('inf') or \ - value != value - got_nan = got_nan or is_nan + value = loss_dict[key].float().sum().item() + is_nan = value == float('inf') or \ + value == -float('inf') or \ + value != value + got_nan = got_nan or is_nan total_loss_dict[nan_iters_key] = total_loss_dict.get( nan_iters_key, 0) + int(got_nan)