diff --git a/Jenkinsfile b/Jenkinsfile
index f6253b16a6d4..14f9a38a9c17 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -212,10 +212,53 @@ pipeline {
             model.unet_config.use_flash_attention=False \
             model.unet_config.attention_resolutions=[1] \
             model.unet_config.channel_mult=[1] \
+            model.ddp_overlap=False \
             "
         sh "rm -rf /home/TestData/multimodal/stable_diffusion_train"
       }
     }
+    stage('L2: Multimodal Stable Diffusion Train with Cuda Graph') {
+      when {
+        anyOf {
+          branch 'main'
+          changeRequest target: 'main'
+        }
+      }
+      failFast true
+      steps {
+        sh "rm -rf /home/TestData/multimodal/stable_diffusion_train_with_cuda_graphs"
+        sh "python examples/multimodal/text_to_image/stable_diffusion/sd_train.py \
+            trainer.precision=16 \
+            trainer.num_nodes=1 \
+            trainer.devices=1 \
+            ++exp_manager.max_time_per_run=00:00:03:00 \
+            exp_manager.exp_dir=/home/TestData/multimodal/stable_diffusion_train_with_cuda_graphs \
+            trainer.max_steps=20 \
+            model.micro_batch_size=1 \
+            model.global_batch_size=1 \
+            model.data.synthetic_data=True \
+            model.first_stage_key=images_moments \
+            model.cond_stage_key=clip_encoded \
+            model.optim.name=megatron_fused_adam \
+            +model.optim.capturable=True \
+            exp_manager.ema.enable=False \
+            model.cond_stage_config._target_=nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder \
+            ++model.cond_stage_config.version=openai/clip-vit-large-patch14 \
+            ++model.cond_stage_config.max_length=77 \
+            model.inductor=False \
+            ~model.cond_stage_config.restore_from_path \
+            ~model.cond_stage_config.freeze \
+            ~model.cond_stage_config.layer \
+            model.first_stage_config.from_pretrained=null \
+            model.ddp_overlap=False \
+            model.capture_cudagraph_iters=15 \
+            model.unet_config.use_flash_attention=False \
+            model.unet_config.attention_resolutions=[1] \
+            model.unet_config.channel_mult=[1] \
+            "
+        sh "rm -rf /home/TestData/multimodal/stable_diffusion_train_with_cuda_graphs"
+      }
+    }
     // stage('L2: Multimodal ControlNet Train') {
     //   when {
     //     anyOf {
diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_lora_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_lora_train.yaml
index 3fbe03aaeaa1..d9981a093288 100644
--- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_lora_train.yaml
+++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_lora_train.yaml
@@ -119,7 +119,7 @@ model:
     use_checkpoint: False
     legacy: False
     use_flash_attention: True
-    enable_amp_o2_fp16: False
+    unet_precision: fp32
     resblock_gn_groups: 32

   lora_network_alpha: null
@@ -214,4 +214,4 @@ model:
     row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
     layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12.
null will apply adapters to all layers weight_tying: False - position_embedding_strategy: null # used only when weight_tying is True \ No newline at end of file + position_embedding_strategy: null # used only when weight_tying is True diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml index 0920ae0870e8..8ce009d5458f 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml @@ -119,7 +119,7 @@ model: use_checkpoint: False legacy: False use_flash_attention: True - enable_amp_o2_fp16: False + unet_precision: fp32 resblock_gn_groups: 32 first_stage_config: diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py index 434150516d0c..968d9bec2884 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py @@ -23,6 +23,7 @@ from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP from nemo.core.config import hydra_runner from nemo.utils import logging +from nemo.utils.callbacks import CUDAGraphCallback from nemo.utils.exp_manager import exp_manager @@ -56,12 +57,41 @@ def main(cfg) -> None: torch.backends.cuda.matmul.allow_tf32 = True - trainer = MegatronStableDiffusionTrainerBuilder(cfg).create_trainer() + callbacks = ( + None + if cfg.model.capture_cudagraph_iters < 0 + else [CUDAGraphCallback(capture_iteration=cfg.model.capture_cudagraph_iters)] + ) + trainer = MegatronStableDiffusionTrainerBuilder(cfg).create_trainer(callbacks) exp_manager(trainer, cfg.exp_manager) model = MegatronLatentDiffusion(cfg.model, trainer) + if cfg.model.capture_cudagraph_iters >= 0: + # Warmup the model with random data + with torch.cuda.stream(torch.cuda.Stream()): + n, c, h = cfg.model.micro_batch_size, cfg.model.channels, cfg.model.image_size + x = torch.randn((n, c, h, h), dtype=torch.float32, device="cuda") + t = torch.randint(77, (n,), device="cuda") + cc = torch.randn((n, 77, cfg.model.unet_config.context_dim), dtype=torch.float32, device="cuda",) + if cfg.model.precision in [16, '16']: + x = x.type(torch.float16) + cc = cc.type(torch.float16) + autocast_enabled = False + dgrad_dtype = torch.float16 + else: + autocast_enabled = True + dgrad_dtype = torch.float16 + + model = model.cuda() + for _ in range(5): + with torch.autocast(device_type="cuda", enabled=autocast_enabled, dtype=torch.float16): + out = model.model.model.diffusion_model(x, t, context=cc) + grad = torch.randn_like(out, dtype=dgrad_dtype) + out.backward(grad) + model.zero_grad() + if cfg.model.get('peft', None): peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] diff --git a/nemo/collections/multimodal/data/stable_diffusion/stable_diffusion_dataset.py b/nemo/collections/multimodal/data/stable_diffusion/stable_diffusion_dataset.py index 5929798267a5..61f1f9c91c01 100644 --- a/nemo/collections/multimodal/data/stable_diffusion/stable_diffusion_dataset.py +++ b/nemo/collections/multimodal/data/stable_diffusion/stable_diffusion_dataset.py @@ -36,18 +36,22 @@ def __init__( self.W = image_W self.image_key = image_key self.txt_key = txt_key - assert image_key.endswith('encoded') == txt_key.endswith( - 'encoded' - ), 'In precached mode, first and second stage key must both end with "encoded"' - self.precached = self.image_key.endswith('encoded') + img_precached = 
image_key.endswith('encoded') or image_key.endswith('moments') + txt_precached = txt_key.endswith('encoded') + assert ( + img_precached == txt_precached + ), 'First and second stage keys should enable/disable precache at the same time.' self.seq_len = seq_len self.context_dim = context_dim def __getitem__(self, index): item = {} - if self.precached: + if self.image_key.endswith('encoded'): item[self.image_key] = torch.randn(8, self.H // 8, self.W // 8) item[self.txt_key] = torch.randn(self.seq_len, self.context_dim) + elif self.image_key.endswith('moments'): + item[self.image_key] = torch.randn(1, 8, self.H // 8, self.W // 8) + item[self.txt_key] = torch.randn(self.seq_len, self.context_dim) else: item[self.image_key] = torch.randn(self.H, self.W, 3) item[self.txt_key] = f'This is meaningless fake text No.{index}' @@ -174,7 +178,7 @@ def transform_fn(sample): if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"): if data_cfg.get('synthetic_data', False): H, W = data_cfg.train.augmentations.center_crop_h_w.split(',') - train_data = SDSyntheticDataset( + val_data = SDSyntheticDataset( int(H), int(W), image_key=model_cfg.first_stage_key, @@ -212,24 +216,46 @@ def transform_fn(sample): # latents are of shape ([4, 64, 64]) return latents, text_embed - train_data = WebDatasetCommon( - dataset_cfg=data_cfg, - consumed_samples=consumed_samples, - map_fn=transform_fn, - compose_fn=tuple_to_dict, - is_train=True, - ) - - val_data = None - if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"): - val_data = WebDatasetCommon( + if data_cfg.get('synthetic_data', False): + H, W = data_cfg.train.augmentations.center_crop_h_w.split(',') + train_data = SDSyntheticDataset( + int(H), + int(W), + image_key=model_cfg.first_stage_key, + txt_key=model_cfg.cond_stage_key, + context_dim=model_cfg.unet_config.context_dim, + seq_len=77, + ) + else: + train_data = WebDatasetCommon( dataset_cfg=data_cfg, consumed_samples=consumed_samples, map_fn=transform_fn, compose_fn=tuple_to_dict, - is_train=False, + is_train=True, ) + val_data = None + if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"): + if data_cfg.get('synthetic_data', False): + H, W = data_cfg.train.augmentations.center_crop_h_w.split(',') + val_data = SDSyntheticDataset( + int(H), + int(W), + image_key=model_cfg.first_stage_key, + txt_key=model_cfg.cond_stage_key, + context_dim=model_cfg.unet_config.context_dim, + seq_len=77, + ) + else: + val_data = WebDatasetCommon( + dataset_cfg=data_cfg, + consumed_samples=consumed_samples, + map_fn=transform_fn, + compose_fn=tuple_to_dict, + is_train=False, + ) + return train_data, val_data diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py index 61bb664e43ed..33a194500a69 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py @@ -283,7 +283,6 @@ def init_from_ckpt( for k in keys: if k.startswith("cond_stage_model"): deleted += 1 - logging.info("Deleting ignored key {} from state_dict.".format(k)) del sd[k] logging.info(f"Deleted {deleted} keys from `cond_stage_model` state_dict.") @@ -294,7 +293,7 @@ def init_from_ckpt( if k.startswith("model.diffusion_model"): deleted += 1 del sd[k] - logging.info(f"Deleted {deleted} keys from `cond_stage_model` state_dict.") + logging.info(f"Deleted {deleted} 
keys from `model.diffusion_model` state_dict.")
         missing, unexpected = (
             self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
         )
@@ -1675,18 +1674,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         # megatron_amp_O2 is not yet supported in diffusion models
         self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)

+        if self.cfg.precision in ['16', 16, 'bf16']:
+            self.model_parallel_config.enable_autocast = False
+
         self.model = self.model_provider_func()

         self.conditioning_keys = []

-        if self.trainer.precision in ['bf16', 'bf16-mixed']:
+        if self.model.precision in ['bf16', 'bf16-mixed']:
             self.autocast_dtype = torch.bfloat16
-        elif self.trainer.precision in [32, '32', '32-true']:
+        elif self.model.precision in [32, '32', '32-true']:
             self.autocast_dtype = torch.float
-        elif self.trainer.precision in [16, '16', '16-mixed']:
+        elif self.model.precision in ['16-mixed', '16', 16]:
             self.autocast_dtype = torch.half
         else:
-            raise ValueError('precision must be in ["32-true", "16-mixed", "bf16-mixed"]')
+            raise ValueError('precision must be in [32, "32", "32-true", "16-mixed", "16", 16, "bf16-mixed", "bf16"]')

         self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
         self.loss_broadcast_src_rank = None
@@ -1780,8 +1782,18 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):

         return loss_mean, loss_dict

-    def training_step(self, dataloader_iter):
+    def training_step(self, batch):
         """
+        Notice: `training_step` used to have the following signature to support pipeline
+        parallelism:
+
+            def training_step(self, dataloader_iter, batch_idx):
+
+        However, the full-iteration CUDA graph callback is not compatible with that signature
+        right now, because the dataloader has to be wrapped so that static tensors are produced
+        outside the CUDA graph. The old signature moves `next(dataloader)` into the CUDA graph
+        capturing region, so it is disabled here.
+
         Our dataloaders produce a micro-batch and then we fetch
         a number of microbatches depending on the global batch size and model parallel size
         from the dataloader to produce a list of microbatches.
@@ -1793,6 +1805,7 @@ def training_step(self, dataloader_iter):
         # we zero grads here because we also call backward in the megatron-core fwd/bwd functions
         self._optimizer.zero_grad()

+        dataloader_iter = iter([batch])
         loss_mean, loss_dict = self.fwd_bwd_step(dataloader_iter, False)

         # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
@@ -1812,6 +1825,8 @@ def training_step(self, dataloader_iter):
             # async grad allreduce is not currently implemented for O1/autocasting mixed precision training
             # so we all-reduce gradients after the pipeline
             self.allreduce_gradients()  # @sangkug we think this is causing memory to blow up (hurts perf)
+        else:
+            raise ValueError("Either distributed_fused_adam or megatron_amp_O2 needs to be set if ddp_overlap is set")

         # for cuda graph with pytorch lightning
         # these values will be used outside the capturing range
@@ -1828,22 +1843,28 @@ def training_step(self, dataloader_iter):
         return loss_mean

     def non_cuda_graph_capturable(self):
+        # Moving CUDA metrics to the CPU forces a host-device sync, so do not show them
+        # on the progress bar when CUDA graph capture is enabled.
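+        # (Setting `show_prog_bar_metric=False` hides these progress-bar metrics explicitly;
+        # they are also hidden automatically whenever `capture_cudagraph_iters >= 0`.)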
+ show_metric = self.cfg.get("show_prog_bar_metric", True) and (self.cfg.get("capture_cudagraph_iters", -1) < 0) + if self.log_train_loss: - self.log('reduced_train_loss', self.loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log('reduced_train_loss', self.loss_mean, prog_bar=show_metric, rank_zero_only=True, batch_size=1) if self.cfg.precision in [16, '16', '16-mixed']: loss_scale = self.trainer.precision_plugin.scaler._scale if loss_scale is not None: self.log('loss_scale', loss_scale, batch_size=1) - self.log_dict(self.loss_dict, prog_bar=False, logger=True, on_step=True, rank_zero_only=True, batch_size=1) + self.log_dict( + self.loss_dict, prog_bar=show_metric, logger=True, on_step=True, rank_zero_only=True, batch_size=1 + ) lr = self._optimizer.param_groups[0]['lr'] - self.log('lr', lr, prog_bar=True, rank_zero_only=True, batch_size=1) - self.log('global_step', self.trainer.global_step + 1, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log('lr', lr, prog_bar=show_metric, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step + 1, prog_bar=show_metric, rank_zero_only=True, batch_size=1) self.log( 'consumed_samples', self.compute_consumed_samples(self.trainer.global_step + 1 - self.init_global_step), - prog_bar=True, + prog_bar=show_metric, rank_zero_only=True, batch_size=1, ) @@ -1902,7 +1923,7 @@ def process_batch(batch): return [x, *c_list] def fwd_output_and_loss_func(dataloader_iter, model): - batch, _, _ = next(dataloader_iter) + batch = next(dataloader_iter) batch = process_batch(batch) batch = [x.cuda(non_blocking=True) for x in batch] if len(self.conditioning_keys) == 0: @@ -1991,7 +2012,7 @@ def build_train_valid_test_datasets(self): raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.") if self.cfg.first_stage_key.endswith("encoded") or self.cfg.first_stage_key.endswith("moments"): - if self.cfg.cond_stage_key.endswith("precached_clip"): + if self.cfg.cond_stage_key.endswith("clip_encoded"): self._train_ds, self._validation_ds = build_train_valid_precached_clip_datasets( model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0), ) @@ -2020,7 +2041,7 @@ def setup_training_data(self, cfg): logging.info( f'Setting up train dataloader with len(len(self._train_ds)): {len(self._train_ds)} and consumed samples: {consumed_samples}' ) - if self.cfg.cond_stage_key.endswith("precached_clip"): + if self.cfg.cond_stage_key.endswith("clip_encoded"): collate_fn = get_collate_fn( first_stage_key=self.cfg.first_stage_key, cond_stage_key=self.cfg.cond_stage_key, ) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 91a214b90713..14560ba5d9d1 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -23,9 +23,11 @@ import torch.nn as nn import torch.nn.functional as F +from apex.contrib.group_norm import GroupNorm from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import ( avg_pool_nd, + build_timestep_embedding, checkpoint, conv_nd, default, @@ -38,16 +40,26 @@ from nemo.utils import logging -def convert_module_to_dtype(module, dtype): +def convert_module_to_dtype(module, dtype, 
enable_norm_layers=False): # Convert module parameters to dtype if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Linear)): module.weight.data = module.weight.data.to(dtype) if module.bias is not None: module.bias.data = module.bias.data.to(dtype) + if enable_norm_layers: + if isinstance(module, (nn.LayerNorm, nn.GroupNorm, GroupNorm)): + module.weight.data = module.weight.data.to(dtype) + if module.bias is not None: + module.bias.data = module.bias.data.to(dtype) -def convert_module_to_fp16(module): - convert_module_to_dtype(module, torch.float16) + +def convert_module_to_fp16(module, enable_norm_layers=False): + convert_module_to_dtype(module, torch.float16, enable_norm_layers) + + +def convert_module_to_fp32(module, enable_norm_layers=False): + convert_module_to_dtype(module, torch.float32, enable_norm_layers) class AttentionPool2d(nn.Module): @@ -538,8 +550,9 @@ def __init__( from_NeMo=False, # It must be specified when from pretrained is not None. It indicates loading unet from NeMo trained ckpt or HF use_flash_attention: bool = False, - enable_amp_o2_fp16: bool = False, + unet_precision: str = "fp32", lora_network_alpha=None, + timesteps=1000, ): super().__init__() from omegaconf.listconfig import ListConfig @@ -616,6 +629,10 @@ def __init__( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) + self.time_embeddings = torch.Tensor(build_timestep_embedding(model_channels, timesteps)).to('cuda') + if unet_precision == 'fp16-mixed' or unet_precision == 'fp16': + self.time_embeddings = self.time_embeddings.to(torch.float16) + if self.num_classes is not None: if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(num_classes, time_embed_dim) @@ -787,6 +804,7 @@ def __init__( dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + resblock_gn_groups=resblock_gn_groups, ) ] ch = model_channels * mult @@ -876,8 +894,12 @@ def __init__( logging.info(f"Missing keys: {missing_key}") logging.info(f"Unexpected keys: {unexpected_keys}") - if enable_amp_o2_fp16: + if unet_precision == "fp16-mixed": # AMP O2 self.convert_to_fp16() + elif unet_precision == 'fp16': + self.convert_to_fp16(enable_norm_layers=True) + + self.unet_precision = unet_precision def _input_blocks_mapping(self, input_dict): res_dict = {} @@ -1120,11 +1142,11 @@ def load(module: torch.nn.Module, prefix=""): return error_msgs - def convert_to_fp16(self): + def convert_to_fp16(self, enable_norm_layers=False): """ Convert the torso of the model to float16. 
""" - self.apply(convert_module_to_fp16) + self.apply(lambda module: convert_module_to_fp16(module=module, enable_norm_layers=enable_norm_layers)) def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ @@ -1145,7 +1167,13 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + + if self.unet_precision == "fp16-mixed" or self.unet_precision == "fp16": + x = x.type(torch.float16) + if context is not None: + context = context.type(torch.float16) + + t_emb = timestep_embedding(timesteps, self.model_channels, cached_embedding=self.time_embeddings) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape[0] == x.shape[0] diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py index d22693a12801..3cf0e45e8e46 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py @@ -177,7 +177,19 @@ def get_idx(end, device): return torch.arange(start=0, end=end, dtype=torch.float32, device=device) -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): +def build_timestep_embedding(dim, max_timesteps, max_period=10000): + timesteps = np.arange(start=0, stop=max_timesteps, dtype=np.float32) + half = dim // 2 + idx = np.arange(start=0, stop=half, dtype=np.float32) + freqs = np.exp(-math.log(max_period) / half * idx) + args = timesteps[:, None] * freqs[None] + embedding = np.concatenate([np.cos(args), np.sin(args)], axis=-1) + if dim % 2: + embedding = np.concatenate([embedding, np.zeros_like(embedding[:, :1])], axis=-1) + return embedding + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, cached_embedding=None): """ Create sinusoidal timestep embeddings. 
@@ -193,13 +205,17 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ if not repeat_only: - half = dim // 2 - idx = get_idx(half, timesteps.device) - freqs = torch.exp(-math.log(max_period) / half * idx) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if cached_embedding is not None: + # using cached embedding and lookup in the cache + embedding = cached_embedding[timesteps, :] + else: + half = dim // 2 + idx = get_idx(half, timesteps.device) + freqs = torch.exp(-math.log(max_period) / half * idx) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) else: embedding = repeat(timesteps, "b -> b d", d=dim) return embedding diff --git a/nemo/collections/nlp/modules/common/megatron/clip_grads.py b/nemo/collections/nlp/modules/common/megatron/clip_grads.py index a643f878dc05..7edc6720574e 100644 --- a/nemo/collections/nlp/modules/common/megatron/clip_grads.py +++ b/nemo/collections/nlp/modules/common/megatron/clip_grads.py @@ -18,8 +18,8 @@ import torch from torch import inf -from nemo.collections.nlp.modules.common.megatron.module import param_is_not_shared from nemo.utils import logging +from nemo.utils.model_utils import param_is_not_shared try: import amp_C @@ -82,7 +82,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, use_fsdp=False): grads_for_norm = [] sharded_grads = [] sharded_grads_for_norm = [] - dummy_overflow_buf = torch.cuda.IntTensor([0]) for param in parameters: if param.grad is not None: @@ -110,7 +109,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, use_fsdp=False): # Norm parameters. max_norm = float(max_norm) norm_type = float(norm_type) - total_norm = 0.0 + total_norm = torch.zeros(1, device='cuda', dtype=torch.float32).squeeze() # Calculate norm. if norm_type == inf: @@ -118,23 +117,20 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, use_fsdp=False): total_norm = max(grad.abs().max() for grad in grads_for_norm) if not use_fsdp: - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) # Take max across all model-parallel GPUs. torch.distributed.all_reduce( - total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group() + total_norm, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group() ) else: if len(sharded_grads_for_norm) > 0: sharded_total_norm = max(grad.abs().max() for grad in sharded_grads_for_norm) total_norm = max(total_norm, sharded_total_norm) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) # Take max across both model-parallel and data-parallel GPUs. - torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX) - - total_norm = total_norm_cuda[0].item() + torch.distributed.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX) else: if norm_type == 2.0: + dummy_overflow_buf = torch.zeros(1, device='cuda', dtype=torch.int32).squeeze() # Use apex's multi-tensor applier for efficiency reasons. # Multi-tensor applier takes a function and a list of list # and performs the operation on that list all in one kernel. 
@@ -143,7 +139,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, use_fsdp=False): amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads_for_norm], False # no per-parameter norm ) else: - grad_norm = 0.0 + grad_norm = torch.zeros(1, device='cuda', dtype=torch.float32).squeeze() # Since we will be summing across data parallel groups, # we need the pow(norm-type). total_norm = grad_norm ** norm_type @@ -153,7 +149,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, use_fsdp=False): amp_C.multi_tensor_l2norm, dummy_overflow_buf.fill_(0), [sharded_grads_for_norm], False ) else: - sharded_grad_norm = 0.0 + sharded_grad_norm = torch.zeros(1, device='cuda', dtype=torch.float32).squeeze() total_sharded_norm = sharded_grad_norm ** norm_type else: for grad in grads_for_norm: @@ -164,29 +160,25 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, use_fsdp=False): grad_norm = torch.norm(grad, norm_type) total_sharded_norm += grad_norm ** norm_type - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) if use_fsdp: - total_sharded_norm_cuda = torch.cuda.FloatTensor([float(total_sharded_norm)]) # Sum norm of grad shards across data-parallel GPUs. torch.distributed.all_reduce( - total_sharded_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=parallel_state.get_data_parallel_group(), + total_sharded_norm, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_data_parallel_group(), ) - total_norm_cuda += total_sharded_norm_cuda + total_norm += total_sharded_norm.squeeze() + # Sum across all model-parallel GPUs. torch.distributed.all_reduce( - total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group() + total_norm, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group() ) - total_norm = total_norm_cuda[0].item() total_norm = total_norm ** (1.0 / norm_type) # Scale. clip_coeff = max_norm / (total_norm + 1.0e-6) - if clip_coeff < 1.0: - if len(grads) > 0 or len(sharded_grads) > 0: # (@adithyare) grads can be empty for adapter training. - grads += sharded_grads - multi_tensor_applier(amp_C.multi_tensor_scale, dummy_overflow_buf.fill_(0), [grads, grads], clip_coeff) + clip_coeff_clamped = torch.clamp(clip_coeff, max=1.0) + if len(grads) > 0 or len(sharded_grads) > 0: # (@adithyare) grads can be empty for adapter training. 
+ grads += sharded_grads + torch._foreach_mul_(grads, clip_coeff_clamped.squeeze()) return total_norm diff --git a/nemo/collections/nlp/modules/common/megatron/module.py b/nemo/collections/nlp/modules/common/megatron/module.py index 88d201d10001..ccd485427c3c 100644 --- a/nemo/collections/nlp/modules/common/megatron/module.py +++ b/nemo/collections/nlp/modules/common/megatron/module.py @@ -38,10 +38,6 @@ _BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor) -def param_is_not_shared(param): - return not hasattr(param, 'shared') or not param.shared - - class MegatronModule(torch.nn.Module): """Megatron specific extensions of torch Module with support for pipelining.""" diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index f77eb7e25813..91f1fab348da 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -53,6 +53,12 @@ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.nn.parallel import DistributedDataParallel +try: + from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state +except ImportError: + # since PyTorch 2.3 the path has changed + from torch.amp.grad_scaler import _refresh_per_optimizer_state + from nemo.collections.multimodal.modules.stable_diffusion.attention import BasicTransformerBlock from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.transformer import AutocastTransformerLayer, ParallelTransformerLayer @@ -1161,16 +1167,26 @@ class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin): def __init__( self, - precision: Literal["16-mixed", "bf16-mixed"], + precision: Literal["16-mixed", "bf16-mixed", '16', 'bf16', 16], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None, ) -> None: - super().__init__(precision, device, scaler=scaler) - dtype = None # MixedPrecisionPlugin class in PTL >= 2.0 takes only "16-mixed" or "bf16-mixed" for precision arg - if precision == '16-mixed': + if precision in ['16-mixed', '16', 16]: + plugin_precision = '16-mixed' + elif precision in ['bf16-mixed', 'bf16']: + plugin_precision = 'bf16-mixed' + else: + raise RuntimeError( + "precision expected to be one of: " + "['16-mixed', '16', 16, 'bf16-mixed', 'bf16']" + f" but {precision} found" + ) + super().__init__(plugin_precision, device, scaler=scaler) + dtype = None + if precision in ['16-mixed', '16', 16]: dtype = torch.float16 - elif precision == 'bf16-mixed': + elif precision in ['bf16-mixed', 'bf16']: dtype = torch.bfloat16 torch.set_autocast_gpu_dtype(dtype) @@ -1331,7 +1347,7 @@ def update(self, new_scale=None): self._hysteresis_tracker = self.hysteresis # To prepare for next iteration, clear the data collected from optimizers this iteration. 
-        self._per_optimizer_states = defaultdict(torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)
+        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

     def state_dict(self):
         """
diff --git a/nemo/core/optim/megatron_fused_adam.py b/nemo/core/optim/megatron_fused_adam.py
index 9278f0a134ef..9035a393c8e8 100755
--- a/nemo/core/optim/megatron_fused_adam.py
+++ b/nemo/core/optim/megatron_fused_adam.py
@@ -14,7 +14,8 @@
 import amp_C
 import torch
-from nemo.collections.nlp.modules.common.megatron.module import param_is_not_shared
+
+from nemo.utils.model_utils import param_is_not_shared

 try:
     from megatron.core import parallel_state
@@ -118,7 +119,7 @@ def step(self, closure=None, grad_scaler=None):
                     False,
                 )
             else:
-                fp32_grad_norm = torch.tensor([0.0], dtype=torch.float32, device=device)
+                fp32_grad_norm = torch.zeros(1, dtype=torch.float32, device=device)

             if fp16_grads_for_norm:
                 fp16_grad_norm, _ = multi_tensor_applier(
@@ -129,7 +130,7 @@ def step(self, closure=None, grad_scaler=None):
                     False,
                 )
             else:
-                fp16_grad_norm = torch.tensor([0.0], dtype=torch.float32, device=device)
+                fp16_grad_norm = torch.zeros(1, dtype=torch.float32, device=device)

             # Prep L2 norm for allreduce
             total_norm = (fp32_grad_norm ** self.norm_type + fp16_grad_norm ** self.norm_type).squeeze()
diff --git a/nemo/utils/callbacks/cuda_graph.py b/nemo/utils/callbacks/cuda_graph.py
index ba6046b79850..247c67856c7b 100644
--- a/nemo/utils/callbacks/cuda_graph.py
+++ b/nemo/utils/callbacks/cuda_graph.py
@@ -12,6 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# CUDAGraphCallback is a full-iteration CUDA graph callback designed for
+# models built on PyTorch Lightning; so far it has only been tested with
+# Stable Diffusion.
+#
+# Prerequisites for this callback:
+# 1. Capturable: the user has to make sure (almost) all host & device
+#    synchronizations are removed. Some of the syncs caused by PyTorch
+#    Lightning's own metric logging are already removed by this callback.
+#    This ensures the graph can be captured.
+# 2. Topology: the user has to make sure there is no dynamic control flow
+#    within the iteration. Please use the APEX alternatives for building
+#    blocks that contain dynamic control flow, e.g. gradient clipping.
+#    Otherwise the captured graph may still run but fail silently,
+#    e.g. with a NaN loss.
+# 3. Parameters: the user has to make sure pointers involved in the graph
+#    capturing range don't change across iterations, i.e. input data has
+#    to be copied into static tensors. Otherwise this can also lead to a
+#    silent failure.
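+#
+# A minimal wiring sketch (this mirrors what
+# examples/multimodal/text_to_image/stable_diffusion/sd_train.py does in this
+# change; it is illustrative rather than the only supported usage):
+#
+#     callbacks = [CUDAGraphCallback(capture_iteration=cfg.model.capture_cudagraph_iters)]
+#     trainer = MegatronStableDiffusionTrainerBuilder(cfg).create_trainer(callbacks)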
+
 import os
 import time
 from dataclasses import dataclass
@@ -20,9 +39,11 @@
 import pytorch_lightning as pl
 import torch
+from pytorch_lightning import LightningModule
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.loops.optimization.automatic import ClosureResult
-from pytorch_lightning.utilities.rank_zero import rank_zero_info
+from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection, _ResultMetric
+from pytorch_lightning.utilities import CombinedLoader, rank_zero_info
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 from torch.nn.parallel import DistributedDataParallel
@@ -104,6 +125,39 @@ def zero_grad(optimizer, *args, **kwargs):
     optimizer.__orig_zero_grad__(*args, **kwargs)


+def to_tensor(self, value, name):
+    # Logging metrics in PyTorch Lightning often triggers CPU & GPU synchronizations. Here
+    # we implement smart metrics to avoid those synchronizations.
+    # Refer to: https://github.com/Lightning-AI/pytorch-lightning/blob/2.0.7/src/lightning/pytorch/core/module.py#L615
+    value = value.clone().detach() if isinstance(value, torch.Tensor) else torch.tensor(value)
+    if not torch.numel(value) == 1:
+        raise ValueError(
+            f"`self.log({name}, {value})` was called, but the tensor must have a single element."
+            f" You can try doing `self.log({name}, {value}.mean())`"
+        )
+    value = value.squeeze()
+    return value
+
+
+def register_key(self, key, meta, value):
+    # PyTorch Lightning creates all metrics on the GPU, but creating the metric on
+    # its input device is preferred.
+    # Refer to: https://github.com/Lightning-AI/pytorch-lightning/blob/2.0.7/src/lightning/pytorch/trainer/connectors/logger_connector/result.py#L409
+    metric = _ResultMetric(meta, isinstance(value, torch.Tensor))
+    device = value.device if isinstance(value, torch.Tensor) else self.device
+    metric = metric.to(device)
+    self[key] = metric
+
+
+def update_metrics(self, key, value, batch_size):
+    # PyTorch Lightning always moves all metrics to the GPU, but moving the metric to
+    # its input device is preferred.
+    result_metric = self[key]
+    device = value.device if isinstance(value, torch.Tensor) else self.device
+    result_metric.forward(value.to(device), batch_size)
+    result_metric.has_reset = False
+
+
 def get_optimizer_step(state):
     def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None,) -> None:
         # Not all optimizer supports set_to_none.
@@ -131,7 +185,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None,) - # Sleep for one second to let environment stable time.sleep(1) rank_zero_info("CUDAGraphCallback: capturing CUDA graph for module %s.", self.__class__.__name__) - with torch.cuda.graph(state.graph, stream=state.stream): + with torch.cuda.graph(state.graph, stream=state.stream, capture_error_mode="global"): self.__orig_optimizer_step__( epoch, batch_idx, optimizer, optimizer_closure=optimizer_closure, ) @@ -152,8 +206,8 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None,) - def get_training_step(state): - def training_step(self, batch, batch_idx): - results = self.__orig_training_step__(batch, batch_idx) + def training_step(self, batch): + results = self.__orig_training_step__(batch) if state.output is None: state.output = struct_copy_one(results) @@ -246,7 +300,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - if self.state.capture_iteration < 0: return - if is_param_in_hook_signature(pl_module, "dataloader_iter", explicit=True): + if is_param_in_hook_signature(pl_module.training_step, "dataloader_iter", explicit=True): raise Exception( "Found `dataloader_iter` argument in the `training_step`. This is " "not supported by full iteration CUDA graph capturing yet since " @@ -270,13 +324,17 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") return # Ensure training dataloader loads data to static buffer - dataloader = trainer.train_dataloader + dataloader = trainer.fit_loop._combined_loader._iterables assert isinstance( dataloader, torch.utils.data.dataloader.DataLoader ), f"Expect Dataloader type but got {type(dataloader)}" - trainer.train_dataloader.__orig_dataloader__ = dataloader static_loader = StaticBufferLoader(dataloader) - trainer.train_dataloader.loaders = static_loader + _mode = trainer.fit_loop._combined_loader._mode + combined_loader = CombinedLoader(static_loader, mode=_mode) + trainer.fit_loop.__orig_combined_loader__ = trainer.fit_loop._combined_loader + trainer.fit_loop._combined_loader = combined_loader + trainer.fit_loop._data_fetcher.setup(trainer.fit_loop._combined_loader) + iter(trainer.fit_loop._data_fetcher) # Warn if `optimizer.zero_grad()` invoked during graph capturing for optimizer in trainer.optimizers: @@ -290,10 +348,18 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") for config in trainer.lr_scheduler_configs: assert isinstance( config.scheduler, torch.optim.lr_scheduler._LRScheduler - ), f"Expect _LRScheduler type but got {type(dataloader)}" + ), f"Expect _LRScheduler type but got {type(config.scheduler)}" config.scheduler.__orig_get_lr__ = config.scheduler.get_lr config.scheduler.get_lr = MethodType(get_lr, config.scheduler) + # Use smart metrics to avoid syncs + LightningModule.__orig_to_tensor__ = LightningModule._LightningModule__to_tensor + LightningModule._LightningModule__to_tensor = to_tensor + _ResultCollection.__orig_register_key__ = _ResultCollection.register_key + _ResultCollection.register_key = register_key + _ResultCollection.__orig_update_metrics__ = _ResultCollection.update_metrics + _ResultCollection.update_metrics = update_metrics + # Save model outputs to static buffer for PL states reconstruct pl_module.__orig_training_step__ = pl_module.training_step training_step = get_training_step(self.state) @@ -309,9 +375,10 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - if 
self.state.capture_iteration < 0: return - dataloader = trainer.train_dataloader.__orig_dataloader__ - trainer.train_dataloader.loaders = dataloader - del trainer.train_dataloader.__orig_dataloader__ + trainer.fit_loop._combined_loader = trainer.fit_loop.__orig_combined_loader__ + trainer.fit_loop._data_fetcher.setup(trainer.fit_loop._combined_loader) + iter(trainer.fit_loop._data_fetcher) + del trainer.fit_loop.__orig_combined_loader__ for optimizer in trainer.optimizers: optimizer.zero_grad = optimizer.__orig_zero_grad__ @@ -321,6 +388,13 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - config.scheduler.get_lr = config.scheduler.__orig_get_lr__ del config.scheduler.__orig_get_lr__ + LightningModule._LightningModule__to_tensor = LightningModule.__orig_to_tensor__ + del LightningModule.__orig_to_tensor__ + _ResultCollection.register_key = _ResultCollection.__orig_register_key__ + del _ResultCollection.__orig_register_key__ + _ResultCollection.update_metrics = _ResultCollection.__orig_update_metrics__ + del _ResultCollection.__orig_update_metrics__ + pl_module.training_step = pl_module.__orig_training_step__ del pl_module.__orig_training_step__ diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index c7047d4e3b52..680e7a723262 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -79,6 +79,10 @@ def load_config(model_file: str) -> DictConfig: return model_config +def param_is_not_shared(param): + return not hasattr(param, 'shared') or not param.shared + + def resolve_dataset_name_from_cfg(cfg: 'DictConfig') -> Optional[str]: """ Parses items of the provided sub-config to find the first potential key that diff --git a/scripts/nlp_language_modeling/merge_lora_weights/merge.py b/scripts/nlp_language_modeling/merge_lora_weights/merge.py index b3d7ca81a674..ccdb433630da 100644 --- a/scripts/nlp_language_modeling/merge_lora_weights/merge.py +++ b/scripts/nlp_language_modeling/merge_lora_weights/merge.py @@ -175,11 +175,7 @@ def main(cfg) -> None: # trainer required for restoring model parallel models trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer) - if ( - cfg.tensor_model_parallel_size < 0 - or cfg.pipeline_model_parallel_size < 0 - or cfg.get('pipeline_model_parallel_split_rank', -1) < 0 - ): + if cfg.tensor_model_parallel_size < 0 or cfg.pipeline_model_parallel_size < 0: model_config = MegatronGPTModel.restore_from( restore_path=cfg.gpt_model_file, trainer=trainer, return_config=True, ) @@ -187,7 +183,6 @@ def main(cfg) -> None: with open_dict(cfg): cfg.tensor_model_parallel_size = model_config.get('tensor_model_parallel_size', 1) cfg.pipeline_model_parallel_size = model_config.get('pipeline_model_parallel_size', 1) - cfg.pipeline_model_parallel_split_rank = model_config.get('pipeline_model_parallel_split_rank', 0) if cfg.gpt_model_file: save_restore_connector = NLPSaveRestoreConnector() @@ -207,7 +202,6 @@ def main(cfg) -> None: pretrained_cfg.activations_checkpoint_method = None pretrained_cfg.precision = trainer.precision pretrained_cfg.use_cpu_initialization = cfg.trainer.accelerator == 'cpu' - pretrained_cfg["apply_rope_fusion"] = False model = MegatronGPTModel.restore_from( restore_path=cfg.gpt_model_file, trainer=trainer, @@ -226,14 +220,12 @@ def main(cfg) -> None: app_state.pipeline_model_parallel_rank, app_state.model_parallel_size, app_state.data_parallel_size, - app_state.pipeline_model_parallel_split_rank, app_state.virtual_pipeline_model_parallel_rank, ) = fake_initialize_model_parallel( 
world_size=app_state.model_parallel_size, rank=trainer.global_rank, tensor_model_parallel_size_=cfg.tensor_model_parallel_size, pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size, - pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank, ) checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)) model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer) diff --git a/tests/core/test_optimizers_schedulers.py b/tests/core/test_optimizers_schedulers.py index 2b185cbe476d..5e5d1ee20c83 100644 --- a/tests/core/test_optimizers_schedulers.py +++ b/tests/core/test_optimizers_schedulers.py @@ -143,7 +143,7 @@ def test_get_optimizer(self): model.cuda() for opt_name in AVAILABLE_OPTIMIZERS.keys(): - if opt_name == 'fused_adam': + if opt_name == 'fused_adam' or opt_name == 'megatron_fused_adam': if not torch.cuda.is_available(): continue if opt_name == 'distributed_fused_adam':