diff --git a/examples/nlp/language_modeling/megatron_gpt_pretraining.py b/examples/nlp/language_modeling/megatron_gpt_pretraining.py
index a9a0d71ed554..4a2f3a912e01 100644
--- a/examples/nlp/language_modeling/megatron_gpt_pretraining.py
+++ b/examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -59,9 +59,6 @@ def main(cfg) -> None:
     # Start new pretraining or resume from a checkpoint if it exists
     else:
         model = MegatronGPTModel(cfg.model, trainer)
-
-    s = torch.cuda.Stream()
-    torch.cuda.set_stream(s)
 
     trainer.fit(model)
 
diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
index 34bfec98e67d..1d5d8d1ee288 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
@@ -41,6 +41,7 @@
     from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build
     from megatron.core.transformer.transformer_layer import BaseTransformerLayer
     from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
+    from megatron.core.transformer.graphs import CudaGraphManager
 
     HAVE_MEGATRON_CORE = True
 
@@ -188,9 +189,6 @@ def __init__(self, config, layer_number=1, hidden_dropout=None):
         self.config = config
         self.is_first_microbatch = True
-        self.sample_inputs = None
-        self.sample_outputs = None
-
         self.enable_cuda_graph = config.enable_cuda_graph
 
         precision = 'bf16' if config.bf16 else 16
 
@@ -324,17 +322,9 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), meta
         return sharded_state_dict
 
     def __call__(self, *args, **kwargs):
-        from megatron.core.transformer.graphs import CudaGraphManager
-
-        if hasattr(self, 'enable_cuda_graph') and self.enable_cuda_graph and self.training:
-            if not hasattr(self, 'cudagraph_manager'):
-                self.add_module('cudagraph_manager', CudaGraphManager())
-
-            out = self.cudagraph_manager(self, args, kwargs)
-        else:
-            out = super(MegatronModule, self).__call__(*args, **kwargs)
-        return out
-
+        if hasattr(self, 'cudagraph_manager'):
+            return self.cudagraph_manager(self, args, kwargs)
+        return super().__call__(*args, **kwargs)
 
 # Use this spec to use the full Transformer layer from Transformer Engine
 def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec:
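
For reviewers, a minimal sketch (not part of the patch) of the dispatch pattern the new `__call__` relies on: the layer no longer imports and lazily constructs a `CudaGraphManager` inside `__call__`; it delegates to `self.cudagraph_manager` only when one has already been attached elsewhere, and otherwise falls back to the regular `nn.Module` call path. `ToyGraphManager` and `ToyLayer` below are hypothetical stand-ins for illustration, not NeMo or Megatron-Core APIs.

```python
# Sketch of the attribute-based dispatch used in the patched __call__.
# ToyGraphManager / ToyLayer are hypothetical; the real classes are
# TETransformerLayerAutocast and megatron.core's CudaGraphManager.
import torch
import torch.nn as nn


class ToyGraphManager(nn.Module):
    """Stand-in for CudaGraphManager: wraps the layer's forward call."""

    def forward(self, layer, args, kwargs):
        # A real manager would capture/replay a CUDA graph here; this
        # sketch simply forwards the call so it runs on any machine.
        return layer.forward(*args, **kwargs)


class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    def __call__(self, *args, **kwargs):
        # Same shape as the patched __call__: delegate only if a manager
        # was attached elsewhere, otherwise fall back to nn.Module.
        if hasattr(self, 'cudagraph_manager'):
            return self.cudagraph_manager(self, args, kwargs)
        return super().__call__(*args, **kwargs)


layer = ToyLayer()
x = torch.randn(2, 4)
print(layer(x).shape)  # plain nn.Module path

layer.add_module('cudagraph_manager', ToyGraphManager())
print(layer(x).shape)  # dispatched through the manager
```

Note the manager is invoked as `manager(layer, args, kwargs)`, matching the call shape kept from the removed branch in the patched `__call__`.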