From 702d23ac7628b4d648cdbaf746d4da93c82e1b4b Mon Sep 17 00:00:00 2001
From: JimmyZhang12
Date: Sat, 21 Sep 2024 00:52:10 +0000
Subject: [PATCH] Apply isort and black reformatting

Signed-off-by: JimmyZhang12
---
 nemo/collections/nlp/parts/nlp_overrides.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index a3d94d69babc..f720f293f369 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -20,9 +20,9 @@
 import tempfile
 from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
+from dataclasses import asdict
 from pathlib import Path
 from typing import Any, Callable, Dict, Generator, Iterator, List, Literal, Mapping, Optional, Sized, Union
-from dataclasses import asdict
 
 import pytorch_lightning as pl
 import torch
@@ -97,6 +97,7 @@
 
 try:
     from megatron.core import dist_checkpointing, parallel_state
+    from megatron.core.dist_checkpointing.core import maybe_load_config
     from megatron.core.dist_checkpointing.dict_utils import dict_list_map_outplace
     from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject
     from megatron.core.dist_checkpointing.optimizer import (
@@ -105,10 +106,10 @@
         optim_state_to_sharding_state,
     )
     from megatron.core.dist_checkpointing.strategies import tensorstore
-    from megatron.core.dist_checkpointing.core import maybe_load_config
     from megatron.core.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate
     from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
     from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer
+
     from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO
 
     HAVE_MEGATRON_CORE = True
@@ -1353,7 +1354,7 @@ def dummy():
                 metadata = asdict(maybe_load_config(tmp_model_weights_dir))
                 sharded_state_dict = instance.sharded_state_dict(metadata=metadata)
                 checkpoint['state_dict'] = sharded_state_dict
-
+
                 checkpoint_io = DistributedCheckpointIO.from_config(conf)
                 checkpoint = checkpoint_io.load_checkpoint(
                     tmp_model_weights_dir,
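
Note for reviewers: beyond the isort/black reshuffling, the code path touched by the last hunk reads the distributed-checkpoint metadata with maybe_load_config, builds a metadata-aware sharded state dict, and loads it through DistributedCheckpointIO. A minimal standalone sketch of that flow follows, assuming `instance` is an already-constructed NeMo Megatron model, `conf` is its model config, and `tmp_model_weights_dir` points at an unpacked dist-ckpt directory; the keyword arguments past the path in load_checkpoint are an assumption, since the hunk above ends before them.

    from dataclasses import asdict

    from megatron.core.dist_checkpointing.core import maybe_load_config
    from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO


    def load_dist_ckpt_weights(instance, conf, tmp_model_weights_dir):
        # Read the checkpoint metadata (a dataclass) from the dist-ckpt
        # directory; asdict() flattens it into the plain dict that
        # sharded_state_dict() takes via its `metadata` argument.
        metadata = asdict(maybe_load_config(tmp_model_weights_dir))
        checkpoint = {'state_dict': instance.sharded_state_dict(metadata=metadata)}

        # DistributedCheckpointIO resolves the load strategy from the model
        # config and performs the sharded load in place of a plain torch.load.
        checkpoint_io = DistributedCheckpointIO.from_config(conf)
        checkpoint = checkpoint_io.load_checkpoint(
            tmp_model_weights_dir,
            sharded_state_dict=checkpoint,  # assumption: kwargs not shown in the hunk
        )
        return checkpoint['state_dict']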