Skip to content

Commit

Permalink
add is_first_microbatch
Browse files Browse the repository at this point in the history
Signed-off-by: jiemingz <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Jan 3, 2024
1 parent be00d21 commit 14e8d54
Showing 1 changed file with 14 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
)

self.transformer_engine = cfg.get('transformer_engine', False)

if self.mcore_gpt and self.cfg.get('fp8', False):
self.num_microbatches_in_previous_step = -1
self.microbatch_count = 0

# configuration used for inference
self._inference_config = None

Expand Down Expand Up @@ -880,9 +883,19 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
forward_args['cu_seqlens_q'] = cu_seqlens
forward_args['cu_seqlens_kv'] = cu_seqlens
forward_args['qkv_format'] = 'thd'

if self.transformer_engine and self.cfg.get('fp8', False):
# Determine if the current iteration is first microbatch
if self.num_microbatches_in_previous_step != get_num_microbatches():
self.microbatch_count = 0 # Reset count on new batch size rampup interval
self.num_microbatches_in_previous_step = get_num_microbatches()
is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
forward_args['is_first_microbatch'] = is_first_microbatch

output_tensor = model(**forward_args)

self.microbatch_count += 1

def loss_func(output_tensor):
# Loss for a micro-batch (ub)
loss_for_ub = self.loss_func(batch['loss_mask'], output_tensor)
Expand Down

0 comments on commit 14e8d54

Please sign in to comment.