From 8b07dab2e998126a480a292374cb8a312f452b64 Mon Sep 17 00:00:00 2001 From: Jimmy Zhang Date: Thu, 1 Feb 2024 10:28:05 -0800 Subject: [PATCH] use parallel_state instead of self Signed-off-by: Jimmy Zhang --- .../language_modeling/megatron_gpt_model.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index a50c8c7265a3..6c31e80e3463 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -287,8 +287,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # Convert the global-batch-based profile index to micro-batch index if hasattr(self, '_nsys_profile_enabled'): mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1) - self.cp_size = cfg.get('context_parallel_size', 1) - data_parallel_world_size = trainer.world_size // (mp_size * self.cp_size) + cp_size = cfg.get('context_parallel_size', 1) + data_parallel_world_size = trainer.world_size // (mp_size * cp_size) grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size) self._nsys_profile_start_step *= grad_accum_steps self._nsys_profile_end_step *= grad_accum_steps @@ -306,8 +306,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): if self.use_loss_mask and self.transformer_config.sequence_parallel: raise ValueError('Loss mask is not supported with sequence parallelism.') - self.cp_size = self.cfg.get('context_parallel_size', 1) - def set_inference_config(self, inference_config): self._inference_config = inference_config @@ -863,18 +861,19 @@ def get_batch_on_this_context_parallel_rank(self, batch): if 'loss_mask' in batch and batch['loss_mask'] is not None: num_valid_tokens_in_ub = batch['loss_mask'].sum() - if self.cp_size > 1: + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: cp_rank = parallel_state.get_context_parallel_rank() for key, val in batch.items(): if val is not None: seq_dim = 1 if key != 'attention_mask' else 2 val = val.view( *val.shape[0:seq_dim], - 2 * self.cp_size, - val.shape[seq_dim] // (2 * self.cp_size), + 2 * cp_size, + val.shape[seq_dim] // (2 * cp_size), *val.shape[(seq_dim + 1) :], ) - index = torch.tensor([cp_rank, (2 * self.cp_size - cp_rank - 1)], device=val.device) + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device) val = val.index_select(seq_dim, index) val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :]) batch[key] = val @@ -958,6 +957,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ def loss_func(output_tensor): # Loss for a micro-batch (ub) loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor) + cp_size = parallel_state.get_context_parallel_world_size() if validation_step and not self.cfg.data.get('validation_drop_last', True): num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub'] if loss_for_ub.isnan(): @@ -976,10 +976,10 @@ def loss_func(output_tensor): torch.distributed.all_reduce( loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group() ) - return loss_for_ub * self.cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} + return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} else: reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) - return loss_for_ub * self.cp_size, {'avg': reduced_loss} + return loss_for_ub * cp_size, {'avg': reduced_loss} return output_tensor, loss_func @@ -1113,7 +1113,7 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): loss_mask = loss_mask.view(-1).float() # TODO: add nemo version here loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll - if self.cp_size > 1: + if parallel_state.get_context_parallel_world_size() > 1: torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) return loss