Apply isort and black reformatting
Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>
JimmyZhang12 committed Aug 20, 2024
1 parent f4530de commit 503cbf7
Showing 1 changed file with 3 additions and 4 deletions.
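
The change itself is mechanical: isort reorders one import, and black reflows one assert and adds a blank line. For reference, the same kind of reformatting can be reproduced through the two tools' Python APIs; the sketch below is illustrative only, and the line_length value is an assumption about the project's black configuration, not something shown on this page.

# Illustrative sketch: apply isort and black programmatically to a snippet.
# The line_length of 119 is an assumed project setting, not read from this repo.
import black
import isort

source = (
    "from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint\n"
    "from megatron.core.transformer.graphs import CudaGraphManager\n"
)

sorted_source = isort.code(source)  # reorders the imports (.graphs before .utils)
formatted = black.format_str(sorted_source, mode=black.Mode(line_length=119))
print(formatted)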
@@ -36,11 +36,11 @@
 try:
     from megatron.core import parallel_state, tensor_parallel
     from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
+    from megatron.core.transformer.graphs import CudaGraphManager
     from megatron.core.transformer.spec_utils import ModuleSpec
     from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build
     from megatron.core.transformer.transformer_layer import BaseTransformerLayer
     from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
-    from megatron.core.transformer.graphs import CudaGraphManager

     HAVE_MEGATRON_CORE = True

@@ -241,9 +241,7 @@ def __init__(self, config, layer_number=1, hidden_dropout=None):
         super().__init__(**transformer_layer_args)

         if not hasattr(self.config, 'external_cuda_graph') and config.enable_cuda_graph and self.training:
-            assert (
-                not config.cpu_offloading and config.recompute_granularity is None
-            ), "Cudagraphs not supported"
+            assert not config.cpu_offloading and config.recompute_granularity is None, "Cudagraphs not supported"
             self.add_module('cudagraph_manager', CudaGraphManager())

     # Called by MCore's TransformerBlock.forward
@@ -334,6 +332,7 @@ def __call__(self, *args, **kwargs):
             return self.cudagraph_manager(self, args, kwargs)
         return super().__call__(*args, **kwargs)

+
 # Use this spec to use the full Transformer layer from Transformer Engine
 def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec:
     if not HAVE_MEGATRON_CORE or not HAVE_TE:
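For context on the function touched in the last hunk, below is a hedged sketch of how a layer spec from get_gpt_full_te_layer_autocast_spec is typically passed to a Megatron-core GPT model. The GPTModel and TransformerConfig arguments follow Megatron-core conventions and are assumptions, not taken from this commit; the changed file's import path is likewise assumed, since it is not shown on this page, and in practice torch.distributed plus Megatron's parallel state must be initialized before the model is built.

# Hedged usage sketch (not part of this commit). Assumes megatron.core is installed
# and torch.distributed / parallel_state have already been initialized.
from megatron.core.models.gpt import GPTModel
from megatron.core.transformer.transformer_config import TransformerConfig

# Import path assumed from the NeMo source tree; the changed file's path is not shown on this page.
from nemo.collections.nlp.models.language_modeling.megatron.gpt_full_te_layer_autocast_spec import (
    get_gpt_full_te_layer_autocast_spec,
)

config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4)
layer_spec = get_gpt_full_te_layer_autocast_spec(config)  # spec built by the function in this diff

model = GPTModel(
    config=config,
    transformer_layer_spec=layer_spec,  # full-TE autocast layers instead of the default spec
    vocab_size=32000,
    max_sequence_length=2048,
)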
