From 32180d39d228cc0ee502555c3e8f6407a386b47f Mon Sep 17 00:00:00 2001
From: Jimmy Zhang
Date: Fri, 19 Apr 2024 09:59:48 -0700
Subject: [PATCH] fix_weight_only_ckpt_save

Signed-off-by: Jimmy Zhang
---
 nemo/collections/nlp/parts/nlp_overrides.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 17a5ac705185..473cc8f94bb9 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -346,14 +346,16 @@ def save_checkpoint(
             hasattr(self.lightning_module, 'sharded_state_dict')
             and self.lightning_module.sharded_state_dict() is not None
         ):
-            assert (
-                len(checkpoint['optimizer_states']) == 1
-            ), "Currently only support checkpointing 1 distributed optimizer per time!"
-            # converts the optimizer states to their sharded equivalents
-            sharded_optim_state = self.optimizer_sharded_state_dict(
-                unsharded_optim_state=checkpoint['optimizer_states'][0]
-            )
-            checkpoint['optimizer_states'] = [sharded_optim_state]
+            if 'optimizer_states' in checkpoint:
+                assert (
+                    len(checkpoint['optimizer_states']) == 1
+                ), "Currently only support checkpointing 1 distributed optimizer per time!"
+                # converts the optimizer states to their sharded equivalents
+                sharded_optim_state = self.optimizer_sharded_state_dict(
+                    unsharded_optim_state=checkpoint['optimizer_states'][0]
+                )
+                checkpoint['optimizer_states'] = [sharded_optim_state]
+
             # dist_checkpointing expects a directory so we will name the directory
             # using the path with the file extension removed
             checkpoint_dir = ckpt_to_dir(filepath)