From d8d3158d4fc96de1cb841ce66c69fdb0d9f4c6f7 Mon Sep 17 00:00:00 2001
From: jiemingz
Date: Wed, 3 Jan 2024 21:30:47 -0800
Subject: [PATCH] add VP support

Signed-off-by: jiemingz
---
 .../models/language_modeling/megatron_gpt_model.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index ee49ce326486..9b9f9695a382 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -889,7 +889,16 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             if self.num_microbatches_in_previous_step != get_num_microbatches():
                 self.microbatch_count = 0  # Reset count on new batch size rampup interval
             self.num_microbatches_in_previous_step = get_num_microbatches()
-            is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
+
+            vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+            if vp_size is not None:
+                # compute fp8 weights first time a model chunk processes a microbatch
+                pp_size = parallel_state.get_pipeline_model_parallel_world_size()
+                forwards_per_step = vp_size * get_num_microbatches()
+                is_first_microbatch = (self.microbatch_count % forwards_per_step < vp_size * pp_size) and \
+                    (self.microbatch_count % pp_size == 0)
+            else:
+                is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
             forward_args['is_first_microbatch'] = is_first_microbatch
 
             output_tensor = model(**forward_args)