From dccd18960233f490df46765ce7ef14e1a8f30f49 Mon Sep 17 00:00:00 2001
From: jiemingz
Date: Sun, 3 Mar 2024 20:28:24 -0800
Subject: [PATCH] avoid duplicate optimizer state dict

Signed-off-by: jiemingz
---
 nemo/collections/nlp/parts/nlp_overrides.py | 12 ++++++++----
 nemo/core/optim/distributed_adam.py         |  5 +++--
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 6560ebe4b37c..2b9100dc4e0a 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -252,7 +252,7 @@ def configure_ddp(self):
         else:
             super().configure_ddp()
 
-    def optimizer_sharded_state_dict(self):
+    def optimizer_sharded_state_dict(self, unsharded_optim_state=None):
         """
         Sharded state dictionary for a MainParamsOptimizerWrapper.
         Used to save and load the optimizer state when training with distributed_checkpoint.
@@ -272,7 +272,7 @@
         }
 
         if isinstance(optimizer, MegatronDistributedFusedAdam):
-            return optimizer.sharded_state_dict(model_sharded_state_dict)
+            return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state)
         elif not isinstance(optimizer, MainParamsOptimizerWrapper):
             # Regular optimizer, e.g. Adam or FusedAdam
             init_optimizer_states(optimizer)
@@ -335,9 +335,13 @@ def save_checkpoint(
             hasattr(self.lightning_module, 'sharded_state_dict')
             and self.lightning_module.sharded_state_dict() is not None
         ):
+            assert len(checkpoint['optimizer_states']) == 1, \
+                "Currently only supports checkpointing one distributed optimizer at a time!"
             # converts the optimizer states to their sharded equivalents
-            checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
-
+            sharded_optim_state = self.optimizer_sharded_state_dict(
+                unsharded_optim_state=checkpoint['optimizer_states'][0]
+            )
+            checkpoint['optimizer_states'] = [sharded_optim_state]
             # dist_checkpointing expects a directory so we will name the directory
             # using the path with the file extension removed
             checkpoint_dir = ckpt_to_dir(filepath)
diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py
index a2316dabb023..de207c1f46ac 100644
--- a/nemo/core/optim/distributed_adam.py
+++ b/nemo/core/optim/distributed_adam.py
@@ -549,8 +549,9 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
         # Handle any remaining dtype conversions
         super()._check_params_shard_dtypes(params_buckets)
 
-    def sharded_state_dict(self, model_sharded_state_dict):
-        optimizer_state_dict = self.state_dict()
+    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
+        if optimizer_state_dict is None:
+            optimizer_state_dict = self.state_dict()
         id_to_sharded_param_map = get_param_id_to_sharded_param_map(
             model_sharded_state_dict=model_sharded_state_dict,
             optim_params_iter=self.parameters(),
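
Note: the point of the change is that save_checkpoint already receives the fully
materialized optimizer state in checkpoint['optimizer_states'], so calling
self.state_dict() again inside sharded_state_dict would gather the same (large)
state a second time. Below is a minimal, self-contained sketch of that pattern;
ToyDistributedOptimizer and save_checkpoint here are hypothetical stand-ins for
illustration, not NeMo's actual classes or signatures.

    # Sketch only: reuse an already-built state dict instead of rebuilding it.
    class ToyDistributedOptimizer:
        def __init__(self):
            self.gather_calls = 0  # counts the expensive state gathers

        def state_dict(self):
            # Stand-in for the expensive gather of optimizer shards.
            self.gather_calls += 1
            return {"state": {0: {"exp_avg": [0.0]}}, "param_groups": [{"lr": 1e-4}]}

        def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
            if optimizer_state_dict is None:
                # Fall back to a fresh gather only when nothing is passed in.
                optimizer_state_dict = self.state_dict()
            # Real code would wrap these entries into ShardedTensors; pass through here.
            return optimizer_state_dict

    def save_checkpoint(optimizer, checkpoint):
        # Mirrors the patched save path: reuse the state the caller already
        # collected instead of triggering a second state_dict() call.
        assert len(checkpoint["optimizer_states"]) == 1
        sharded = optimizer.sharded_state_dict(
            model_sharded_state_dict={},
            optimizer_state_dict=checkpoint["optimizer_states"][0],
        )
        checkpoint["optimizer_states"] = [sharded]
        return checkpoint

    opt = ToyDistributedOptimizer()
    ckpt = {"optimizer_states": [opt.state_dict()]}  # built once by the caller
    save_checkpoint(opt, ckpt)
    print(opt.gather_calls)  # 1 -- the duplicate gather is avoided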