diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 17a5ac705185..473cc8f94bb9 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -346,14 +346,16 @@ def save_checkpoint(
             hasattr(self.lightning_module, 'sharded_state_dict')
             and self.lightning_module.sharded_state_dict() is not None
         ):
-            assert (
-                len(checkpoint['optimizer_states']) == 1
-            ), "Currently only support checkpointing 1 distributed optimizer per time!"
-            # converts the optimizer states to their sharded equivalents
-            sharded_optim_state = self.optimizer_sharded_state_dict(
-                unsharded_optim_state=checkpoint['optimizer_states'][0]
-            )
-            checkpoint['optimizer_states'] = [sharded_optim_state]
+            if 'optimizer_states' in checkpoint:
+                assert (
+                    len(checkpoint['optimizer_states']) == 1
+                ), "Currently only support checkpointing 1 distributed optimizer per time!"
+                # converts the optimizer states to their sharded equivalents
+                sharded_optim_state = self.optimizer_sharded_state_dict(
+                    unsharded_optim_state=checkpoint['optimizer_states'][0]
+                )
+                checkpoint['optimizer_states'] = [sharded_optim_state]
+
         # dist_checkpointing expects a directory so we will name the directory
         # using the path with the file extension removed
         checkpoint_dir = ckpt_to_dir(filepath)
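For context, a minimal sketch of the pattern this change introduces: the optimizer-state conversion now only runs when the checkpoint dictionary actually contains an 'optimizer_states' entry, so checkpoints saved without optimizer state (e.g. weights-only saves) no longer hit a KeyError. The shard_optimizer_state helper below is a hypothetical stand-in for the strategy's optimizer_sharded_state_dict call in the diff, not NeMo's real implementation.

    # Sketch only; shard_optimizer_state is a placeholder, not NeMo code.
    from typing import Any, Dict


    def shard_optimizer_state(unsharded_optim_state: Dict[str, Any]) -> Dict[str, Any]:
        # Placeholder for the real sharding logic (assumption).
        return {"sharded": unsharded_optim_state}


    def convert_optimizer_states(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
        # A checkpoint may omit 'optimizer_states' entirely, so guard
        # before indexing into it; otherwise convert the single entry.
        if 'optimizer_states' in checkpoint:
            assert (
                len(checkpoint['optimizer_states']) == 1
            ), "Currently only support checkpointing 1 distributed optimizer per time!"
            checkpoint['optimizer_states'] = [
                shard_optimizer_state(checkpoint['optimizer_states'][0])
            ]
        return checkpoint


    # Usage: a checkpoint without optimizer state passes through unchanged.
    print(convert_optimizer_states({"state_dict": {}}))
    print(convert_optimizer_states({"state_dict": {}, "optimizer_states": [{"step": 1}]}))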