From b1d632cd695f74379c71685ed376206466711eb8 Mon Sep 17 00:00:00 2001
From: Jimmy Zhang
Date: Mon, 29 Jul 2024 15:51:27 -0700
Subject: [PATCH] cuda graph modules

Signed-off-by: Jimmy Zhang
---
 .../megatron_gpt_pretraining.py               |  3 +++
 .../gpt_full_te_layer_autocast_spec.py        | 23 +++++++++++++++----
 .../language_modeling/megatron_gpt_model.py   |  8 ++++---
 3 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/examples/nlp/language_modeling/megatron_gpt_pretraining.py b/examples/nlp/language_modeling/megatron_gpt_pretraining.py
index 422319a382c8..a9a0d71ed554 100644
--- a/examples/nlp/language_modeling/megatron_gpt_pretraining.py
+++ b/examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -60,6 +60,9 @@ def main(cfg) -> None:
     else:
         model = MegatronGPTModel(cfg.model, trainer)
 
+    s = torch.cuda.Stream()
+    torch.cuda.set_stream(s)
+
     trainer.fit(model)
 
 
diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
index f3299d488fd0..34bfec98e67d 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
@@ -20,6 +20,7 @@
 from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults
 from nemo.collections.nlp.parts import utils_funcs
 
+_IS_GRAPH_CAPTURING=False
 
 try:
     from transformer_engine.pytorch import TransformerLayer
@@ -187,6 +188,10 @@ def __init__(self, config, layer_number=1, hidden_dropout=None):
         self.config = config
         self.is_first_microbatch = True
 
+        self.sample_inputs = None
+        self.sample_outputs = None
+        self.enable_cuda_graph = config.enable_cuda_graph
+
         precision = 'bf16' if config.bf16 else 16
 
         transformer_layer_args = {
@@ -263,12 +268,8 @@ def forward(
             # checkpoint_core_attention,
         )
         self.is_first_microbatch = False
-        context = None
 
-        # CUDA graph requires returned values to be Tensors
-        if self.config.enable_cuda_graph and self.training:
-            return hidden_states
-        return hidden_states, context
+        return hidden_states, None
 
     def _get_layer_offset(self):
 
@@ -322,6 +323,18 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), meta
 
         return sharded_state_dict
 
+    def __call__(self, *args, **kwargs):
+        from megatron.core.transformer.graphs import CudaGraphManager
+
+        if hasattr(self, 'enable_cuda_graph') and self.enable_cuda_graph and self.training:
+            if not hasattr(self, 'cudagraph_manager'):
+                self.add_module('cudagraph_manager', CudaGraphManager())
+
+            out = self.cudagraph_manager(self, args, kwargs)
+        else:
+            out = super(MegatronModule, self).__call__(*args, **kwargs)
+        return out
+
 
 # Use this spec to use the full Transformer layer from Transformer Engine
 def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec:
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 6e7a145679e0..291c4576d2ea 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -811,9 +811,11 @@ def training_step(self, dataloader_iter):
                 module = module.module
             if not self.mcore_gpt:
                 module = module.language_model
-            if hasattr(module, 'embedding'):
-                for param in module.embedding.parameters():
-                    param.data_ptr()
+            # if hasattr(module, 'embedding'):
+            #     for param in module.embedding.parameters():
+            #         param.data_ptr()
+            for param in module.parameters():
+                param.data_ptr()
         if self.cfg.get('pipeline_model_parallel_size', 1) > 1 and parallel_state.is_pipeline_last_stage(
             ignore_virtual=True
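
The sketch below is not part of the patch; it is a minimal, self-contained illustration of the capture-and-replay pattern that the __call__ override above wires into TETransformerLayerAutocast, using PyTorch's stock torch.cuda.make_graphed_callables in place of Megatron-core's CudaGraphManager (whose internals the patch does not show). The toy layer, hidden size, and batch size are arbitrary assumptions; a CUDA-capable GPU is required.

# Illustrative sketch only, not taken from the patch. Assumes a CUDA device; the
# toy layer, hidden size, and batch size are arbitrary.
import torch
import torch.nn as nn

assert torch.cuda.is_available(), "CUDA graphs require a CUDA device"

# Mirrors the megatron_gpt_pretraining.py change above, presumably so graph capture
# does not run on the legacy default stream (capturing the default stream is not
# allowed). It is not strictly needed for make_graphed_callables, which captures on
# an internal side stream.
torch.cuda.set_stream(torch.cuda.Stream())

layer = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).cuda().train()
sample = torch.randn(8, 1024, device="cuda")

# One-time warmup and capture: returns the same module with forward() swapped for a
# graphed version; later calls replay the captured forward/backward kernels.
layer = torch.cuda.make_graphed_callables(layer, (sample,))

opt = torch.optim.SGD(layer.parameters(), lr=1e-3)
for _ in range(3):
    x = torch.randn(8, 1024, device="cuda")  # shapes must match the capture-time sample
    loss = layer(x).sum()
    loss.backward()  # replays the captured backward graph
    opt.step()
    opt.zero_grad()

In the patch itself, CudaGraphManager plays this role per transformer layer, with routing through captured graphs gated by the enable_cuda_graph config flag and self.training inside the overridden __call__.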