From 1280f59c1a65e50d4e174e4195e14f173301a497 Mon Sep 17 00:00:00 2001
From: billishyahao
Date: Wed, 28 Aug 2024 01:22:20 +0800
Subject: [PATCH] [Bug] Fix crash when logging optimizer state to tb (#417)

---
 megatron/training.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/megatron/training.py b/megatron/training.py
index 79f39ccc2e..0aeaabeba5 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -1032,6 +1032,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     if args.log_optimizer_states_to_tensorboard and optimizer is not None:
         opt_stats = [0.0] * 8
         opt_stats_2 = [0.0] * 4
+
+        #TODO(billishyahao): Remove me after bf16_optimizer promotes its state.
+        if not hasattr(optimizer, "state"):
+            assert hasattr(optimizer, "optimizer"), f"Optimizer must have optimizer property."
+            optimizer.state = optimizer.optimizer.state
+
     for _, group in enumerate(optimizer.param_groups):
         for _, param in enumerate(group['params']):
             opt_stats[0] += (torch.norm(optimizer.state[param]['exp_avg_sq']).item())**2
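
For context, here is a minimal standalone sketch of the failure mode the patch guards against. It is not part of the patch: `WrapperOptimizer` is a hypothetical stand-in for an optimizer wrapper (such as a bf16 optimizer) that exposes the wrapped torch optimizer as `.optimizer` but does not forward `.state`, which is what makes `optimizer.state[param]` crash during TensorBoard logging.

```python
import torch

class WrapperOptimizer:
    """Hypothetical wrapper lacking a `state` attribute, mirroring the bug."""
    def __init__(self, inner):
        self.optimizer = inner                # wrapped torch.optim optimizer
        self.param_groups = inner.param_groups

param = torch.nn.Parameter(torch.ones(4))
inner = torch.optim.Adam([param])
param.sum().backward()
inner.step()                                  # populates exp_avg / exp_avg_sq

optimizer = WrapperOptimizer(inner)

# Same shim as the patch: alias the wrapped optimizer's state dict.
if not hasattr(optimizer, "state"):
    assert hasattr(optimizer, "optimizer"), "Optimizer must have optimizer property."
    optimizer.state = optimizer.optimizer.state

# Without the shim this line raises AttributeError on the wrapper.
norm_sq = torch.norm(optimizer.state[param]['exp_avg_sq']).item() ** 2
print(norm_sq)
```

Aliasing rather than copying keeps `optimizer.state` live: later `step()` calls mutate the same dict the logging code reads, so no re-sync is needed.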