From 83877fdf0a763847ff80c9d8e28bfa786977e866 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 18 Mar 2024 23:05:28 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 nemo/collections/nlp/parts/nlp_overrides.py | 13 ++++++-------
 nemo/core/optim/distributed_adam.py         |  1 -
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 00e27230eac2..26443c248431 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -334,8 +334,9 @@ def save_checkpoint(
                 hasattr(self.lightning_module, 'sharded_state_dict')
                 and self.lightning_module.sharded_state_dict() is not None
             ):
-                assert len(checkpoint['optimizer_states']) == 1, \
-                    "Currently only support checkpointing 1 distributed optimizer per time!"
+                assert (
+                    len(checkpoint['optimizer_states']) == 1
+                ), "Currently only support checkpointing 1 distributed optimizer per time!"
                 # converts the optimizer states to their sharded equivalents
                 sharded_optim_state = self.optimizer_sharded_state_dict(
                     unsharded_optim_state=checkpoint['optimizer_states'][0]
@@ -433,15 +434,13 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
 
             # after dist_checkpointing.load, sharded tensors will be replaced with tensors
             checkpoint['state_dict'] = sharded_state_dict
-            checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()] 
+            checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
             strategy = dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(
                 load_directly_on_device=True
             )
 
             checkpoint = dist_checkpointing.load(
-                sharded_state_dict=checkpoint,
-                checkpoint_dir=checkpoint_path,
-                sharded_strategy=strategy
-            )
+                sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path, sharded_strategy=strategy
+            )
 
             return checkpoint
 
diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py
index de207c1f46ac..a85747c9f640 100644
--- a/nemo/core/optim/distributed_adam.py
+++ b/nemo/core/optim/distributed_adam.py
@@ -549,7 +549,6 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
         # Handle any remaining dtype conversions
         super()._check_params_shard_dtypes(params_buckets)
 
-
     def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
         if optimizer_state_dict is None:
            optimizer_state_dict = self.state_dict()
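
Note (not part of the upstream commit): the hunks above only reflow formatting; the
behavior of the distributed-checkpoint load path is unchanged. As a rough sketch of how
that path drives Megatron-Core's dist_checkpointing API, assuming a module that exposes
sharded_state_dict() — the helper name load_distributed_checkpoint and its arguments
here are hypothetical, while the dist_checkpointing.load call and the TensorStore
strategy are taken verbatim from the patched code:

    from pathlib import Path
    from typing import Union

    from megatron.core import dist_checkpointing
    from megatron.core.dist_checkpointing.strategies import tensorstore


    def load_distributed_checkpoint(module, checkpoint_path: Union[str, Path]):
        # Describe what to load: a mapping whose ShardedTensor entries tell
        # dist_checkpointing.load which shards this rank owns. The patched
        # load_checkpoint builds this from self.lightning_module.sharded_state_dict().
        checkpoint = {'state_dict': module.sharded_state_dict()}

        # Same strategy the patch reformats: TensorStore-backed reads that place
        # shards directly on the target device.
        strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)

        # After load(), the ShardedTensor placeholders are replaced with real tensors.
        return dist_checkpointing.load(
            sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path, sharded_strategy=strategy
        )

load_directly_on_device=True presumably exists to skip a host-memory staging copy when
the destination tensors already live on the GPU, which matters at the multi-gigabyte
checkpoint sizes these model-parallel runs produce.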