
Commit

bug fixes
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed Aug 9, 2024
1 parent b1d632c commit 74e2725
Showing 2 changed files with 4 additions and 17 deletions.
3 changes: 0 additions & 3 deletions examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -59,9 +59,6 @@ def main(cfg) -> None:
     # Start new pretraining or resume from a checkpoint if it exists
     else:
         model = MegatronGPTModel(cfg.model, trainer)
-
-        s = torch.cuda.Stream()
-        torch.cuda.set_stream(s)
 
     trainer.fit(model)
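
For context, a minimal sketch (not part of the commit) of what the two removed lines in megatron_gpt_pretraining.py did: torch.cuda.set_stream swaps the current CUDA stream for the calling thread, so setting it inside main() before trainer.fit(model) pushed every subsequent kernel launch onto that side stream instead of PyTorch's default stream.

# Illustrative only (needs a CUDA-capable GPU); demonstrates the side effect,
# for the calling thread, of the two lines this commit removes.
import torch

if torch.cuda.is_available():
    print(torch.cuda.current_stream() == torch.cuda.default_stream())  # True

    s = torch.cuda.Stream()   # the removed code created a fresh side stream...
    torch.cuda.set_stream(s)  # ...and made it current for everything launched afterwards

    print(torch.cuda.current_stream() == torch.cuda.default_stream())  # False
    print(torch.cuda.current_stream() == s)                            # True

    # Undo the override so any later code runs on the default stream again.
    torch.cuda.set_stream(torch.cuda.default_stream())
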
18 changes: 4 additions & 14 deletions
@@ -41,6 +41,7 @@
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build
 from megatron.core.transformer.transformer_layer import BaseTransformerLayer
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
+from megatron.core.transformer.graphs import CudaGraphManager
 
 HAVE_MEGATRON_CORE = True
 
@@ -188,9 +189,6 @@ def __init__(self, config, layer_number=1, hidden_dropout=None):
 
         self.config = config
         self.is_first_microbatch = True
-        self.sample_inputs = None
-        self.sample_outputs = None
-        self.enable_cuda_graph = config.enable_cuda_graph
 
         precision = 'bf16' if config.bf16 else 16
 
@@ -324,17 +322,9 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), meta
         return sharded_state_dict
 
     def __call__(self, *args, **kwargs):
-        from megatron.core.transformer.graphs import CudaGraphManager
-
-        if hasattr(self, 'enable_cuda_graph') and self.enable_cuda_graph and self.training:
-            if not hasattr(self, 'cudagraph_manager'):
-                self.add_module('cudagraph_manager', CudaGraphManager())
-
-            out = self.cudagraph_manager(self, args, kwargs)
-        else:
-            out = super(MegatronModule, self).__call__(*args, **kwargs)
-        return out
-
+        if hasattr(self, 'cudagraph_manager'):
+            return self.cudagraph_manager(self, args, kwargs)
+        return super().__call__(*args, **kwargs)
 
 # Use this spec to use the full Transformer layer from Transformer Engine
 def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec:
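
To make the new dispatch easier to follow, here is a minimal, runnable sketch (not from this commit; ToyLayer and ToyGraphManager are hypothetical stand-ins for the patched layer class and Megatron-Core's CudaGraphManager). After this change, __call__ simply routes through a cudagraph_manager submodule whenever one has been attached to the layer, and otherwise falls back to the normal torch.nn.Module call path; the old in-place checks of enable_cuda_graph and self.training are gone.

# Illustrative sketch only; ToyLayer / ToyGraphManager are made-up stand-ins.
# A real CudaGraphManager would capture the layer's forward pass into a CUDA
# graph and replay it, rather than just forwarding the call as done here.
import torch
import torch.nn as nn


class ToyGraphManager(nn.Module):
    def forward(self, layer, args, kwargs):
        # Stand-in behavior: simply run the wrapped layer's forward pass.
        return layer.forward(*args, **kwargs)


class ToyLayer(nn.Module):
    def __init__(self, enable_cuda_graph: bool = False):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        if enable_cuda_graph:
            # The manager is attached at construction time, so __call__ no
            # longer needs the lazy enable_cuda_graph / self.training checks.
            self.add_module('cudagraph_manager', ToyGraphManager())

    def forward(self, x):
        return self.linear(x)

    def __call__(self, *args, **kwargs):
        # Same shape as the patched __call__: dispatch on submodule presence.
        if hasattr(self, 'cudagraph_manager'):
            return self.cudagraph_manager(self, args, kwargs)
        return super().__call__(*args, **kwargs)


x = torch.randn(2, 4)
print(ToyLayer(enable_cuda_graph=False)(x).shape)  # plain nn.Module path
print(ToyLayer(enable_cuda_graph=True)(x).shape)   # routed through the manager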
