Skip to content

Commit

Permalink
add VP support
Browse files Browse the repository at this point in the history
Signed-off-by: jiemingz <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Jan 4, 2024
1 parent b6c2f59 commit d8d3158
Showing 1 changed file with 10 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,16 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
# Reset the per-step microbatch counter whenever the number of microbatches
# changes (batch-size rampup alters it between steps).
if self.num_microbatches_in_previous_step != get_num_microbatches():
    self.microbatch_count = 0 # Reset count on new batch size rampup interval
self.num_microbatches_in_previous_step = get_num_microbatches()

# Virtual-pipeline (interleaved schedule) world size; None when interleaving
# is disabled.
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
if vp_size is not None:
    # compute fp8 weights first time a model chunk processes a microbatch
    # NOTE(review): with vp_size model chunks per rank, this forward function
    # is invoked vp_size * num_microbatches times per step; the condition
    # below is meant to flag the first forward of each chunk under the
    # interleaved schedule — confirm against the Megatron interleaved
    # pipeline schedule's chunk-dispatch order.
    pp_size = parallel_state.get_pipeline_model_parallel_world_size()
    forwards_per_step = vp_size * get_num_microbatches()
    is_first_microbatch = (self.microbatch_count % forwards_per_step < vp_size*pp_size) and \
        (self.microbatch_count % pp_size == 0)
else:
    # No interleaving: exactly one model chunk, so the first microbatch of
    # the step is simply when the counter wraps to a multiple of
    # num_microbatches.
    is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
# Passed through to the model (e.g. Transformer Engine uses it to decide
# when to (re)compute fp8 weight casts/amax history for the step).
forward_args['is_first_microbatch'] = is_first_microbatch

output_tensor = model(**forward_args)
Expand Down

0 comments on commit d8d3158

Please sign in to comment.