Skip to content

Commit

Permalink
load checkpoint directly to GPU
Browse files (browse the repository at this point in the history)
Signed-off-by: jiemingz <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Mar 4, 2024
1 parent dccd189 commit e5880c3
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -433,11 +433,15 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:

# after dist_checkpointing.load, sharded tensors will be replaced with tensors
checkpoint['state_dict'] = sharded_state_dict
- checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
-
- checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path)
-
- checkpoint = self._fix_tensors_device(checkpoint)
+ checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
+ strategy = dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(
+ load_directly_on_device=True
+ )
+ checkpoint = dist_checkpointing.load(
+ sharded_state_dict=checkpoint,
+ checkpoint_dir=checkpoint_path,
+ sharded_strategy=strategy
+ )

return checkpoint

Expand Down

0 comments on commit e5880c3

Please sign in to comment.