From 9de6189affca2c9ab67378ced7002ac4188c6ae3 Mon Sep 17 00:00:00 2001
From: Jimmy Zhang
Date: Wed, 28 Aug 2024 16:56:40 -0700
Subject: [PATCH] update mcore changes

Signed-off-by: Jimmy Zhang
---
 .../megatron/gpt_full_te_layer_autocast_spec.py        | 2 +-
 .../nlp/models/language_modeling/megatron_gpt_model.py | 9 ++-------
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
index d94d218ef5d0..6ca9e83ce0e6 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
@@ -36,7 +36,7 @@
 try:
     from megatron.core import parallel_state, tensor_parallel
     from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
-    from megatron.core.transformer.graphs import CudaGraphManager
+    from megatron.core.transformer.cuda_graphs import CudaGraphManager
     from megatron.core.transformer.spec_utils import ModuleSpec
     from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build
     from megatron.core.transformer.transformer_layer import BaseTransformerLayer
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 327a9990801b..22da07922eb2 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -835,14 +835,9 @@ def training_step(self, dataloader_iter):
             if not self.mcore_gpt:
                 module = module.language_model

-            # Cudagraphed model does not trigger param sync hooks, so manually trigger param syncs here.
-            if self.cfg.get('enable_cuda_graph', False):
-                for param in module.parameters():
+            if hasattr(module, 'embedding'):
+                for param in module.embedding.parameters():
                     param.data_ptr()
-            else:
-                if hasattr(module, 'embedding'):
-                    for param in module.embedding.parameters():
-                        param.data_ptr()

         if self.cfg.get('pipeline_model_parallel_size', 1) > 1 and parallel_state.is_pipeline_last_stage(
             ignore_virtual=True
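
Note on the two hunks above. The first only tracks the Megatron-core rename visible in the diff: CudaGraphManager is now imported from megatron.core.transformer.cuda_graphs instead of megatron.core.transformer.graphs. The second drops the enable_cuda_graph branch that touched every parameter and keeps only the walk over embedding parameters; per the comment removed in that hunk, touching a parameter's data pointer is what manually triggers the param sync hooks. The following is a small self-contained sketch of that post-patch behaviour, using a hypothetical TinyGPTStub class rather than the real model module, so the surrounding training_step context is assumed, not taken from the file.

# Sketch only: TinyGPTStub is a hypothetical stand-in for the module handled in
# training_step; the real code operates on the NeMo/Megatron model module.
import torch.nn as nn


class TinyGPTStub(nn.Module):
    """Hypothetical module with an 'embedding' submodule, mirroring the attribute checked in the patch."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(16, 8)
        self.linear = nn.Linear(8, 8)


module = TinyGPTStub()

# Post-patch behaviour: only the embedding parameters are touched, not every parameter.
if hasattr(module, 'embedding'):
    for param in module.embedding.parameters():
        # Accessing the data pointer is the manual trigger for the param sync hooks
        # described in the comment that the hunk removes.
        param.data_ptr()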