Skip to content

Commit

Permalink
fix_weight_only_ckpt_save
Browse files (browse the repository at this point in the history)
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Apr 19, 2024
1 parent 3865746 commit 32180d3
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,16 @@ def save_checkpoint(
hasattr(self.lightning_module, 'sharded_state_dict')
and self.lightning_module.sharded_state_dict() is not None
):
- assert (
- len(checkpoint['optimizer_states']) == 1
- ), "Currently only support checkpointing 1 distributed optimizer per time!"
- # converts the optimizer states to their sharded equivalents
- sharded_optim_state = self.optimizer_sharded_state_dict(
- unsharded_optim_state=checkpoint['optimizer_states'][0]
- )
- checkpoint['optimizer_states'] = [sharded_optim_state]
+ if 'optimizer_states' in checkpoint:
+ assert (
+ len(checkpoint['optimizer_states']) == 1
+ ), "Currently only support checkpointing 1 distributed optimizer per time!"
+ # converts the optimizer states to their sharded equivalents
+ sharded_optim_state = self.optimizer_sharded_state_dict(
+ unsharded_optim_state=checkpoint['optimizer_states'][0]
+ )
+ checkpoint['optimizer_states'] = [sharded_optim_state]
+
# dist_checkpointing expects a directory so we will name the directory
# using the path with the file extension removed
checkpoint_dir = ckpt_to_dir(filepath)
Expand Down

0 comments on commit 32180d3

Please sign in to comment.