From 83546266c06cac637733b6c1feb91ac9017f2de5 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 4 Mar 2024 19:53:56 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 nemo/collections/nlp/parts/nlp_overrides.py | 13 ++++++-------
 nemo/core/optim/distributed_adam.py         |  1 -
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 4cc12c752b61..d7a954eb5423 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -336,8 +336,9 @@ def save_checkpoint(
             hasattr(self.lightning_module, 'sharded_state_dict')
             and self.lightning_module.sharded_state_dict() is not None
         ):
-            assert len(checkpoint['optimizer_states']) == 1, \
-                "Currently only support checkpointing 1 distributed optimizer per time!"
+            assert (
+                len(checkpoint['optimizer_states']) == 1
+            ), "Currently only support checkpointing 1 distributed optimizer per time!"
             # converts the optimizer states to their sharded equivalents
             sharded_optim_state = self.optimizer_sharded_state_dict(
                 unsharded_optim_state=checkpoint['optimizer_states'][0]
@@ -434,15 +435,13 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
 
         # after dist_checkpointing.load, sharded tensors will be replaced with tensors
         checkpoint['state_dict'] = sharded_state_dict
-        checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()] 
+        checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
         strategy = dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(
             load_directly_on_device=True
         )
         checkpoint = dist_checkpointing.load(
-            sharded_state_dict=checkpoint,
-            checkpoint_dir=checkpoint_path,
-            sharded_strategy=strategy
-        )
+            sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path, sharded_strategy=strategy
+        )
 
         return checkpoint
 
diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py
index 8fece107a548..fe43932c53d6 100644
--- a/nemo/core/optim/distributed_adam.py
+++ b/nemo/core/optim/distributed_adam.py
@@ -511,7 +511,6 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
         # Handle any remaining dtype conversions
         super()._check_params_shard_dtypes(params_buckets)
 
-
     def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
         if optimizer_state_dict is None:
             optimizer_state_dict = self.state_dict()