Skip to content

Commit

Permalink
fix loading with torch dist ckpt
Browse files (browse the repository at this point in the history)
Signed-off-by: jiemingz <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Mar 23, 2024
1 parent a44253d commit 7fbf9e2
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -442,9 +442,13 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
  # after dist_checkpointing.load, sharded tensors will be replaced with tensors
  checkpoint['state_dict'] = sharded_state_dict
  checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
- strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
+
+ if self.torch_dist_ckpt:
+ sharded_strategy = ('torch_dist', 1)
+ else:
+ sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
  checkpoint = dist_checkpointing.load(
- sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path, sharded_strategy=strategy
+ sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path, sharded_strategy=sharded_strategy
  )
 
  return checkpoint
Expand Down

0 comments on commit 7fbf9e2

Please sign in to comment.