add first_val_step for mcore schedules
Signed-off-by: jiemingz <jiemingz@nvidia.com>
jiemingz committed Jan 11, 2024
1 parent 0a1a5b1 commit 2e280bc
Showing 1 changed file with 7 additions and 2 deletions.
@@ -224,6 +224,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         self.mcore_gpt = cfg.get('mcore_gpt', False)
         self.spec_name = cfg.get('name', '')
+        if cfg.get('fp8', False):
+            self.prev_step_training = True
 
         self.rampup_batch_size = self.cfg.get('rampup_batch_size', None)
         if self.rampup_batch_size:
@@ -485,7 +487,7 @@ def forward(self, tokens, text_position_ids, attention_mask, labels):
         output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels)
         return output_tensor
 
-    def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
+    def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only, first_val_step=False):
 
         # handle asynchronous grad reduction
         no_sync_func = None
@@ -515,6 +517,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             forward_only=forward_only,
             seq_length=self.cfg.encoder_seq_length,
             micro_batch_size=self.cfg.micro_batch_size,
+            first_val_step=first_val_step,
         )
 
         # only the last stages of the pipeline return losses
@@ -1012,9 +1015,11 @@ def validation_step(self, dataloader_iter, batch_idx):
         if isinstance(self.model, list):
             for model_module in self.model:
                 model_module.eval()
+        first_val_step = self.prev_step_training and not self.training
 
-        loss = self.fwd_bwd_step(dataloader_iter, batch_idx, True)
+        loss = self.fwd_bwd_step(dataloader_iter, batch_idx, True, first_val_step)
 
+        self.prev_step_training = self.training
         if isinstance(self.model, list):
             for model_module in self.model:
                 model_module.train()
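For context on the pattern this diff introduces: the model remembers whether the previous step ran in training mode (prev_step_training, initialized only when fp8 is enabled), and validation_step passes first_val_step=True down to the mcore forward/backward schedule exactly when the validation step directly follows training. That the flag is tracked only under fp8 suggests it exists so the schedule can handle FP8-related state across the train-to-validation transition, though the diff itself does not spell that out. Below is a minimal, self-contained sketch of just this bookkeeping; TinyModel, its print-only fwd_bwd_step, and the dummy loss are illustrative stand-ins, not the actual NeMo or Megatron-core implementation.

# Minimal sketch of the prev_step_training / first_val_step bookkeeping added
# in this commit. TinyModel is a stand-in, not the real NeMo model class.

class TinyModel:
    def __init__(self, fp8=False):
        self.training = True  # stand-in for the Lightning train/eval flag
        # Mirrors the diff: the previous step's mode is tracked only when FP8 is enabled.
        if fp8:
            self.prev_step_training = True

    def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only, first_val_step=False):
        # The real method forwards first_val_step to the mcore schedule;
        # here we only report what would be passed down and return a dummy loss.
        mode = "val" if forward_only else "train"
        print(f"{mode} step {batch_idx}: first_val_step={first_val_step}")
        return 0.0

    def training_step(self, dataloader_iter, batch_idx):
        self.training = True  # Lightning puts the module back into train mode
        return self.fwd_bwd_step(dataloader_iter, batch_idx, False)

    def validation_step(self, dataloader_iter, batch_idx):
        self.training = False  # Lightning switches the module to eval mode
        # True only when the previous step ran in training mode, i.e. this is
        # the first validation step after training.
        first_val_step = self.prev_step_training and not self.training
        loss = self.fwd_bwd_step(dataloader_iter, batch_idx, True, first_val_step)
        self.prev_step_training = self.training
        return loss


if __name__ == "__main__":
    model = TinyModel(fp8=True)
    model.training_step(iter([]), 0)
    model.validation_step(iter([]), 0)  # prints first_val_step=True
    model.validation_step(iter([]), 1)  # prints first_val_step=False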
