Skip to content

Commit

Permalink
Fix SFT missing arg
Browse files Browse the repository at this point in the history
Signed-off-by: jiemingz <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Jan 24, 2024
1 parent 861b633 commit 03d4f9c
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _determine_log_key(self, data_config, dataloader_idx, metric_name, mode):
else:
return base_key + f"dataloader{dataloader_idx}"

def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only, first_val_step=None):
batch = next(dataloader_iter)

log_token_counts = self.cfg.get('log_token_counts', False)
Expand Down Expand Up @@ -360,6 +360,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
forward_only=forward_only,
seq_length=seq_length,
micro_batch_size=get_micro_batch_size(),
first_val_step=first_val_step,
)

# only the last stages of the pipeline return losses
Expand Down

0 comments on commit 03d4f9c

Please sign in to comment.