diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 1aed60b4f1de..ebcca1df697d 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -66,6 +66,7 @@ try: from apex.transformer.pipeline_parallel.utils import get_num_microbatches + from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam HAVE_APEX = True @@ -87,12 +88,12 @@ try: from megatron.core import dist_checkpointing, parallel_state from megatron.core.dist_checkpointing.dict_utils import dict_list_map_outplace - from megatron.core.dist_checkpointing.strategies import tensorstore from megatron.core.dist_checkpointing.optimizer import ( get_param_id_to_sharded_param_map, make_sharded_optimizer_tensor, optim_state_to_sharding_state, ) + from megatron.core.dist_checkpointing.strategies import tensorstore from megatron.core.transformer.module import Float16Module as MCoreFloat16Module from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer