
Commit dccd189
avoid duplicate optimizer state dict fix
Signed-off-by: jiemingz <jiemingz@nvidia.com>
jiemingz committed Mar 4, 2024
1 parent 3211494 commit dccd189
Showing 2 changed files with 12 additions and 6 deletions.
nemo/collections/nlp/parts/nlp_overrides.py — 12 changes: 8 additions & 4 deletions
@@ -252,7 +252,7 @@ def configure_ddp(self):
         else:
             super().configure_ddp()

-    def optimizer_sharded_state_dict(self):
+    def optimizer_sharded_state_dict(self, unsharded_optim_state=None):
         """
         Sharded state dictionary for an MainParamsOptimizerWrapper.
         Used to save and load the optimizer state when training with distributed_checkpoint.
@@ -272,7 +272,7 @@ def optimizer_sharded_state_dict(self):
         }

         if isinstance(optimizer, MegatronDistributedFusedAdam):
-            return optimizer.sharded_state_dict(model_sharded_state_dict)
+            return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state)
         elif not isinstance(optimizer, MainParamsOptimizerWrapper):
             # Regular optimizer, e.g. Adam or FusedAdam
             init_optimizer_states(optimizer)
@@ -335,9 +335,13 @@ def save_checkpoint(
             hasattr(self.lightning_module, 'sharded_state_dict')
             and self.lightning_module.sharded_state_dict() is not None
         ):
+            assert len(checkpoint['optimizer_states']) == 1, \
+                "Currently only support checkpointing 1 distributed optimizer per time!"
             # converts the optimizer states to their sharded equivalents
-            checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
-
+            sharded_optim_state = self.optimizer_sharded_state_dict(
+                unsharded_optim_state=checkpoint['optimizer_states'][0]
+            )
+            checkpoint['optimizer_states'] = [sharded_optim_state]
             # dist_checkpointing expects a directory so we will name the directory
             # using the path with the file extension removed
             checkpoint_dir = ckpt_to_dir(filepath)
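
To make the intent of the save_checkpoint hunk above concrete, here is a minimal, self-contained sketch of the new save path (DummyStrategy and its hard-coded state are illustrative stand-ins, not part of the commit): Lightning has already collected the full optimizer state into checkpoint['optimizer_states'], so the strategy converts that existing copy instead of asking the optimizer for a second one.

# Illustrative sketch only: DummyStrategy stands in for NLPDDPStrategy and the
# dict returned below stands in for the real sharded conversion.
class DummyStrategy:
    def optimizer_sharded_state_dict(self, unsharded_optim_state=None):
        # Reuse the state the caller already has; only rebuild it when absent.
        state = unsharded_optim_state if unsharded_optim_state is not None else {"step": 0}
        return {"sharded": True, "state": state}

    def save_checkpoint(self, checkpoint):
        assert len(checkpoint["optimizer_states"]) == 1, \
            "Currently only support checkpointing 1 distributed optimizer per time!"
        # Convert the optimizer state Lightning already gathered, rather than
        # calling optimizer.state_dict() again and holding two copies in memory.
        sharded_optim_state = self.optimizer_sharded_state_dict(
            unsharded_optim_state=checkpoint["optimizer_states"][0]
        )
        checkpoint["optimizer_states"] = [sharded_optim_state]
        return checkpoint

checkpoint = {"optimizer_states": [{"step": 42}], "state_dict": {}}
print(DummyStrategy().save_checkpoint(checkpoint)["optimizer_states"])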
nemo/core/optim/distributed_adam.py — 6 changes: 4 additions & 2 deletions
@@ -549,8 +549,10 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
         # Handle any remaining dtype conversions
         super()._check_params_shard_dtypes(params_buckets)

-    def sharded_state_dict(self, model_sharded_state_dict):
-        optimizer_state_dict = self.state_dict()
+    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
+        if optimizer_state_dict is None:
+            optimizer_state_dict = self.state_dict()
+
         id_to_sharded_param_map = get_param_id_to_sharded_param_map(
             model_sharded_state_dict=model_sharded_state_dict, optim_params_iter=self.parameters(),
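
The distributed_adam.py change is the other half of the same pattern: sharded_state_dict() now accepts an optional, pre-built optimizer_state_dict and only calls self.state_dict() when none is supplied. A minimal standalone sketch of that fallback follows (TinyOptimizer is hypothetical, not the MegatronDistributedFusedAdam class, which maps the state onto sharded tensors via get_param_id_to_sharded_param_map).

# Hypothetical TinyOptimizer illustrating the optional-state fallback.
class TinyOptimizer:
    def __init__(self):
        self._calls = 0
        self._state = {"step": 0}

    def state_dict(self):
        # Potentially expensive for a real distributed optimizer.
        self._calls += 1
        return {"state": dict(self._state)}

    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
        # Only build a fresh copy when the caller did not pass one in.
        if optimizer_state_dict is None:
            optimizer_state_dict = self.state_dict()
        return {"model": model_sharded_state_dict, "optimizer": optimizer_state_dict}

opt = TinyOptimizer()
existing = opt.state_dict()              # collected once, e.g. by Lightning
opt.sharded_state_dict({}, existing)     # reuses it: no second state_dict() call
assert opt._calls == 1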
