Skip to content

Commit

Permalink
use parallel_state instead of self
Browse files Browse the repository at this point in the history
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Feb 2, 2024
1 parent e247926 commit 8b07dab
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
# Convert the global-batch-based profile index to micro-batch index
if hasattr(self, '_nsys_profile_enabled'):
mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1)
self.cp_size = cfg.get('context_parallel_size', 1)
data_parallel_world_size = trainer.world_size // (mp_size * self.cp_size)
cp_size = cfg.get('context_parallel_size', 1)
data_parallel_world_size = trainer.world_size // (mp_size * cp_size)
grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size)
self._nsys_profile_start_step *= grad_accum_steps
self._nsys_profile_end_step *= grad_accum_steps
Expand All @@ -306,8 +306,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
if self.use_loss_mask and self.transformer_config.sequence_parallel:
raise ValueError('Loss mask is not supported with sequence parallelism.')

self.cp_size = self.cfg.get('context_parallel_size', 1)

def set_inference_config(self, inference_config):
self._inference_config = inference_config

Expand Down Expand Up @@ -863,18 +861,19 @@ def get_batch_on_this_context_parallel_rank(self, batch):
if 'loss_mask' in batch and batch['loss_mask'] is not None:
num_valid_tokens_in_ub = batch['loss_mask'].sum()

if self.cp_size > 1:
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
cp_rank = parallel_state.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * self.cp_size,
val.shape[seq_dim] // (2 * self.cp_size),
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * self.cp_size - cp_rank - 1)], device=val.device)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val
Expand Down Expand Up @@ -958,6 +957,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
def loss_func(output_tensor):
# Loss for a micro-batch (ub)
loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
cp_size = parallel_state.get_context_parallel_world_size()
if validation_step and not self.cfg.data.get('validation_drop_last', True):
num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub']
if loss_for_ub.isnan():
Expand All @@ -976,10 +976,10 @@ def loss_func(output_tensor):
torch.distributed.all_reduce(
loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group()
)
return loss_for_ub * self.cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
else:
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
return loss_for_ub * self.cp_size, {'avg': reduced_loss}
return loss_for_ub * cp_size, {'avg': reduced_loss}

return output_tensor, loss_func

Expand Down Expand Up @@ -1113,7 +1113,7 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
loss_mask = loss_mask.view(-1).float()
# TODO: add nemo version here
loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll
if self.cp_size > 1:
if parallel_state.get_context_parallel_world_size() > 1:
torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
return loss

Expand Down

0 comments on commit 8b07dab

Please sign in to comment.