diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 2b9100dc4e0a..b57bb64f93a0 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -433,11 +433,15 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: # after dist_checkpointing.load, sharded tensors will be replaced with tensors checkpoint['state_dict'] = sharded_state_dict - checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()] - - checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path) - - checkpoint = self._fix_tensors_device(checkpoint) + checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()] + strategy = dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy( + load_directly_on_device=True + ) + checkpoint = dist_checkpointing.load( + sharded_state_dict=checkpoint, + checkpoint_dir=checkpoint_path, + sharded_strategy=strategy + ) return checkpoint