From 7fbf9e226cc18b43e1c0b9307eb380726d1b977a Mon Sep 17 00:00:00 2001
From: jiemingz
Date: Sat, 23 Mar 2024 01:17:24 -0700
Subject: [PATCH] fix loading with torch dist ckpt

Signed-off-by: jiemingz
---
 nemo/collections/nlp/parts/nlp_overrides.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index a84785bcf407..bfbc916c89d1 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -442,9 +442,13 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
             # after dist_checkpointing.load, sharded tensors will be replaced with tensors
             checkpoint['state_dict'] = sharded_state_dict
             checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
-            strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
+
+            if self.torch_dist_ckpt:
+                sharded_strategy = ('torch_dist', 1)
+            else:
+                sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
             checkpoint = dist_checkpointing.load(
-                sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path, sharded_strategy=strategy
+                sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path, sharded_strategy=sharded_strategy
             )
             return checkpoint
 