diff --git a/megatron/optimizer/distrib_optimizer.py b/megatron/optimizer/distrib_optimizer.py index 96786394ae..7a53e24b11 100644 --- a/megatron/optimizer/distrib_optimizer.py +++ b/megatron/optimizer/distrib_optimizer.py @@ -334,7 +334,7 @@ def build_model_and_main_param_groups(cls, 'torch.cuda.FloatTensor, ' 'torch.cuda.HalfTensor, or ' 'torch.cuda.BFloat16Tensor. ' - 'Received {}'.format(param.type())) + 'Received {}'.format(model_param.type())) # Update optimizer's params. group_range["orig_group"]["params"] = [ @@ -386,7 +386,7 @@ def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, self.model_param_group_index_map, self.opt_group_ranges = \ self.build_optimizer_group_ranges(self.optimizer.param_groups, self.model_gbuf_ranges) - + # Allocate main param shards. ( self.model_float16_groups, @@ -630,7 +630,7 @@ def save_parameter_state(self, filename): # Gather contiguous shards on DP rank 0. world_tensors = {} for key, send_tensor in local_shards.items(): - + # Gather tensor list. if data_parallel_rank == 0: recv_tensors = [torch.empty((gbuf_local_numel,), @@ -700,7 +700,7 @@ def load_parameter_state(self, filename): # Scatter local shards from DP rank 0. for key, recv_tensor in local_shards.items(): - + # Scatter tensor list. if data_parallel_rank == 0: world_tensor = loaded_state[model_idx][dtype][key]