add cudagraph manager
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed Aug 20, 2024
1 parent 24b201c · commit 0d51bf0
Showing 1 changed file with 8 additions and 2 deletions.
@@ -240,6 +240,12 @@ def __init__(self, config, layer_number=1, hidden_dropout=None):
         transformer_layer_args["ub_atomic_gemm_rs"] = config.tp_comm_atomic_rs
         super().__init__(**transformer_layer_args)
 
+        if config.enable_cuda_graph and self.training:
+            assert (
+                not config.cpu_offloading and config.recompute_granularity is None
+            ), "Cudagraphs not supported"
+            self.add_module('cudagraph_manager', CudaGraphManager())
+
     # Called by MCore's TransformerBlock.forward
     # megatron/core/transformer/transformer_block.py
     def forward(
@@ -266,8 +272,8 @@ def forward(
         self.is_first_microbatch = False
         context = None
 
-        # CUDA graph requires returned values to be Tensors
-        if self.config.enable_cuda_graph and self.training:
+        # External CUDA graph requires returned values to be Tensors
+        if hasattr(self.config, 'external_cuda_graph') and self.config.external_cuda_graph and self.training:
             return hidden_states
         return hidden_states, context
 
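For context, the return-type constraint in `forward` comes from how CUDA graph capture works: a captured callable must return plain Tensors backed by static buffers, which `replay()` refreshes in place; a non-Tensor value such as the `None` in `(hidden_states, context)` cannot be refreshed that way. The sketch below illustrates the capture/replay pattern with PyTorch's public `torch.cuda.CUDAGraph` API. It is an inference-style illustration under stated assumptions, not the `CudaGraphManager` this commit adds; `make_graphed` and its helpers are hypothetical names.

```python
# A minimal capture/replay sketch, assuming a CUDA device and a PyTorch
# build with CUDA graph support (torch.cuda.CUDAGraph / torch.cuda.graph).
# Illustrative only; this is not NeMo's CudaGraphManager.
import torch

def make_graphed(fn, *static_inputs):
    # Warm up on a side stream so one-time kernel/allocator work
    # is not recorded into the graph (the pattern PyTorch's docs use).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        for _ in range(3):
            fn(*static_inputs)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.no_grad(), torch.cuda.graph(graph):
        # The captured output must be a Tensor (or a structure of Tensors):
        # replay() rewrites these static buffers in place, and a non-Tensor
        # value such as None has no buffer to refresh.
        static_output = fn(*static_inputs)

    def replay(*inputs):
        for buf, x in zip(static_inputs, inputs):
            buf.copy_(x)          # feed new data into the captured input buffers
        graph.replay()            # re-run the recorded kernels
        return static_output      # same Tensor object, updated contents

    return replay

# Usage example under the same assumptions.
if torch.cuda.is_available():
    layer = torch.nn.Linear(1024, 1024, device="cuda")
    x = torch.randn(8, 1024, device="cuda")
    graphed = make_graphed(layer, x)
    y = graphed(torch.randn(8, 1024, device="cuda"))
```

This is consistent with the `forward` change above: when the external-CUDA-graph path is active during training, the layer returns `hidden_states` alone (a Tensor) rather than the usual `(hidden_states, context)` tuple, whose `context` may be `None`.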
