[SD] CUDA Graphs update (NVIDIA#8613)
* [SD] remove synchronizations

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Typo in logging

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* [SD] Remove the sync invoked by tensor allocation.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Make the model sync-free again.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Support PyTorch Lightning 2 for full iteration CUDA graph callback.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Add documentation about CUDAGraphCallback.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Support synthetic dataset.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Fix typo.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Fix the bug of wrong GN groups.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* remove circular dependency

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Change naming for offline clip

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Add exception when no gradient allreduce is called.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* rename enable_amp_o2_fp16 -> unet_precision

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Adjustments to PyTorch 2.3

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* fix CUDA Graphs support in SD

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Document incompatibility between pipe parallelism and full iteration CUDA Graph callback for SD.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* update CUDA Graphs callback to PTL 2.1

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* [SD] Full-fp16: push normalization layers in FP16.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* [SD] enable CUDA Graphs in examples

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* [SD] add model warmup

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* fix sanity-check for CUDA Graphs

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* [SD] CUDA Graphs test

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Update cuda graph jenkins test

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* fix typo

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* fix path in test

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* handle unexpected precision value for PipelineMixedPrecisionPlugin

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* remove unused import

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* replace unsupported syntax

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* Add a guard for megatron fused adam

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix bugs for FSDP in clip_grads

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

* [SD] skip model warmup when CUDA Graph not captured

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>

---------

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>
Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>
Signed-off-by: Marek Wawrzos <marek.28.93@gmail.com>
Co-authored-by: Szymon Mikler <smikler@nvidia.com>
Co-authored-by: Wil Kong <alpha0422@gmail.com>
Co-authored-by: Mengdi Wang <didow@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ming <111467530+Victor49152@users.noreply.github.com>
Co-authored-by: Mingyuan Ma <mingyuanm@nvidia.com>
7 people authored Apr 9, 2024
1 parent 23baa48 commit c573826
Showing 15 changed files with 349 additions and 102 deletions.
43 changes: 43 additions & 0 deletions Jenkinsfile
@@ -212,10 +212,53 @@ pipeline {
model.unet_config.use_flash_attention=False \
model.unet_config.attention_resolutions=[1] \
model.unet_config.channel_mult=[1] \
model.ddp_overlap=False \
"
sh "rm -rf /home/TestData/multimodal/stable_diffusion_train"
}
}
stage('L2: Multimodal Stable Diffusion Train with Cuda Graph') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps {
sh "rm -rf /home/TestData/multimodal/stable_diffusion_train_with_cuda_graphs"
sh "python examples/multimodal/text_to_image/stable_diffusion/sd_train.py \
trainer.precision=16 \
trainer.num_nodes=1 \
trainer.devices=1 \
++exp_manager.max_time_per_run=00:00:03:00 \
exp_manager.exp_dir=/home/TestData/multimodal/stable_diffusion_train_with_cuda_graph \
trainer.max_steps=20 \
model.micro_batch_size=1 \
model.global_batch_size=1 \
model.data.synthetic_data=True \
model.first_stage_key=images_moments \
model.cond_stage_key=clip_encoded \
model.optim.name=megatron_fused_adam \
+model.optim.capturable=True \
exp_manager.ema.enable=False \
model.cond_stage_config._target_=nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder \
++model.cond_stage_config.version=openai/clip-vit-large-patch14 \
++model.cond_stage_config.max_length=77 \
model.inductor=False \
~model.cond_stage_config.restore_from_path \
~model.cond_stage_config.freeze \
~model.cond_stage_config.layer \
model.first_stage_config.from_pretrained=null \
model.ddp_overlap=False \
model.capture_cudagraph_iters=15 \
model.unet_config.use_flash_attention=False \
model.unet_config.attention_resolutions=[1] \
model.unet_config.channel_mult=[1] \
"
sh "rm -rf /home/TestData/multimodal/stable_diffusion_train_with_cuda_graphs"
}
}
// stage('L2: Multimodal ControlNet Train') {
// when {
// anyOf {
@@ -119,7 +119,7 @@ model:
use_checkpoint: False
legacy: False
use_flash_attention: True
enable_amp_o2_fp16: False
unet_precision: fp32
resblock_gn_groups: 32
lora_network_alpha: null

@@ -214,4 +214,4 @@ model:
row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers
weight_tying: False
position_embedding_strategy: null # used only when weight_tying is True
position_embedding_strategy: null # used only when weight_tying is True
@@ -119,7 +119,7 @@ model:
use_checkpoint: False
legacy: False
use_flash_attention: True
enable_amp_o2_fp16: False
unet_precision: fp32
resblock_gn_groups: 32

first_stage_config:
32 changes: 31 additions & 1 deletion examples/multimodal/text_to_image/stable_diffusion/sd_train.py
@@ -23,6 +23,7 @@
from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.callbacks import CUDAGraphCallback
from nemo.utils.exp_manager import exp_manager


@@ -56,12 +57,41 @@ def main(cfg) -> None:

torch.backends.cuda.matmul.allow_tf32 = True

trainer = MegatronStableDiffusionTrainerBuilder(cfg).create_trainer()
callbacks = (
None
if cfg.model.capture_cudagraph_iters < 0
else [CUDAGraphCallback(capture_iteration=cfg.model.capture_cudagraph_iters)]
)
trainer = MegatronStableDiffusionTrainerBuilder(cfg).create_trainer(callbacks)

exp_manager(trainer, cfg.exp_manager)

model = MegatronLatentDiffusion(cfg.model, trainer)

if cfg.model.capture_cudagraph_iters >= 0:
# Warmup the model with random data
with torch.cuda.stream(torch.cuda.Stream()):
n, c, h = cfg.model.micro_batch_size, cfg.model.channels, cfg.model.image_size
x = torch.randn((n, c, h, h), dtype=torch.float32, device="cuda")
t = torch.randint(77, (n,), device="cuda")
cc = torch.randn((n, 77, cfg.model.unet_config.context_dim), dtype=torch.float32, device="cuda",)
if cfg.model.precision in [16, '16']:
x = x.type(torch.float16)
cc = cc.type(torch.float16)
autocast_enabled = False
dgrad_dtype = torch.float16
else:
autocast_enabled = True
dgrad_dtype = torch.float16

model = model.cuda()
for _ in range(5):
with torch.autocast(device_type="cuda", enabled=autocast_enabled, dtype=torch.float16):
out = model.model.model.diffusion_model(x, t, context=cc)
grad = torch.randn_like(out, dtype=dgrad_dtype)
out.backward(grad)
model.zero_grad()

if cfg.model.get('peft', None):

peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]
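
The callback wiring in this hunk treats a negative `capture_cudagraph_iters` as "CUDA Graphs disabled". A minimal sketch of that selection logic follows. `FakeCUDAGraphCallback` and `select_callbacks` are hypothetical names introduced only for illustration, standing in for NeMo's `CUDAGraphCallback` so the snippet runs without NeMo or a GPU.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeCUDAGraphCallback:
    """Hypothetical stand-in for nemo.utils.callbacks.CUDAGraphCallback."""

    capture_iteration: int


def select_callbacks(capture_cudagraph_iters: int) -> Optional[List[FakeCUDAGraphCallback]]:
    """Mirror the diff: a negative value means no callback is passed to the trainer."""
    if capture_cudagraph_iters < 0:
        return None
    # Zero or positive: capture the graph at that training iteration.
    return [FakeCUDAGraphCallback(capture_iteration=capture_cudagraph_iters)]
```

With the Jenkins setting `model.capture_cudagraph_iters=15`, this sketch would return a single callback configured for iteration 15; the same `>= 0` condition is what gates the warmup loop in the hunk above.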
@@ -36,18 +36,22 @@ def __init__(
self.W = image_W
self.image_key = image_key
self.txt_key = txt_key
assert image_key.endswith('encoded') == txt_key.endswith(
'encoded'
), 'In precached mode, first and second stage key must both end with "encoded"'
self.precached = self.image_key.endswith('encoded')
img_precached = image_key.endswith('encoded') or image_key.endswith('moments')
txt_precached = txt_key.endswith('encoded')
assert (
img_precached == txt_precached
), 'First and second stage keys should enable/disable precache at the same time.'
self.seq_len = seq_len
self.context_dim = context_dim

def __getitem__(self, index):
item = {}
if self.precached:
if self.image_key.endswith('encoded'):
item[self.image_key] = torch.randn(8, self.H // 8, self.W // 8)
item[self.txt_key] = torch.randn(self.seq_len, self.context_dim)
elif self.image_key.endswith('moments'):
item[self.image_key] = torch.randn(1, 8, self.H // 8, self.W // 8)
item[self.txt_key] = torch.randn(self.seq_len, self.context_dim)
else:
item[self.image_key] = torch.randn(self.H, self.W, 3)
item[self.txt_key] = f'This is meaningless fake text No.{index}'
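
The key-suffix rules in the synthetic dataset can be summarized without tensors. The sketch below only computes the shapes the hunk passes to `torch.randn`. The function names are hypothetical, and it raises `ValueError` where the original uses `assert`.

```python
from typing import Tuple


def is_precached(image_key: str, txt_key: str) -> bool:
    """Both keys must agree on whether precached features are used."""
    img_precached = image_key.endswith("encoded") or image_key.endswith("moments")
    txt_precached = txt_key.endswith("encoded")
    if img_precached != txt_precached:
        raise ValueError("First and second stage keys should enable/disable precache at the same time.")
    return img_precached


def synthetic_image_shape(image_key: str, height: int, width: int) -> Tuple[int, ...]:
    """Return the shape of the random image-side tensor for a given key."""
    if image_key.endswith("encoded"):
        return (8, height // 8, width // 8)
    if image_key.endswith("moments"):
        # The 'moments' variant carries an extra leading dimension of size 1.
        return (1, 8, height // 8, width // 8)
    # Raw-image mode: height x width x RGB.
    return (height, width, 3)
```

For the Jenkins configuration (`first_stage_key=images_moments`, `cond_stage_key=clip_encoded`), both keys count as precached, so the consistency check passes.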
@@ -174,7 +178,7 @@ def transform_fn(sample):
if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"):
if data_cfg.get('synthetic_data', False):
H, W = data_cfg.train.augmentations.center_crop_h_w.split(',')
train_data = SDSyntheticDataset(
val_data = SDSyntheticDataset(
int(H),
int(W),
image_key=model_cfg.first_stage_key,
@@ -212,24 +216,46 @@ def transform_fn(sample):
# latents are of shape ([4, 64, 64])
return latents, text_embed

train_data = WebDatasetCommon(
dataset_cfg=data_cfg,
consumed_samples=consumed_samples,
map_fn=transform_fn,
compose_fn=tuple_to_dict,
is_train=True,
)

val_data = None
if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"):
val_data = WebDatasetCommon(
if data_cfg.get('synthetic_data', False):
H, W = data_cfg.train.augmentations.center_crop_h_w.split(',')
train_data = SDSyntheticDataset(
int(H),
int(W),
image_key=model_cfg.first_stage_key,
txt_key=model_cfg.cond_stage_key,
context_dim=model_cfg.unet_config.context_dim,
seq_len=77,
)
else:
train_data = WebDatasetCommon(
dataset_cfg=data_cfg,
consumed_samples=consumed_samples,
map_fn=transform_fn,
compose_fn=tuple_to_dict,
is_train=False,
is_train=True,
)

val_data = None
if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"):
if data_cfg.get('synthetic_data', False):
H, W = data_cfg.train.augmentations.center_crop_h_w.split(',')
val_data = SDSyntheticDataset(
int(H),
int(W),
image_key=model_cfg.first_stage_key,
txt_key=model_cfg.cond_stage_key,
context_dim=model_cfg.unet_config.context_dim,
seq_len=77,
)
else:
val_data = WebDatasetCommon(
dataset_cfg=data_cfg,
consumed_samples=consumed_samples,
map_fn=transform_fn,
compose_fn=tuple_to_dict,
is_train=False,
)

return train_data, val_data


@@ -283,7 +283,6 @@ def init_from_ckpt(
for k in keys:
if k.startswith("cond_stage_model"):
deleted += 1
logging.info("Deleting ignored key {} from state_dict.".format(k))
del sd[k]
logging.info(f"Deleted {deleted} keys from `cond_stage_model` state_dict.")

@@ -294,7 +293,7 @@ def init_from_ckpt(
if k.startswith("model.diffusion_model"):
deleted += 1
del sd[k]
logging.info(f"Deleted {deleted} keys from `cond_stage_model` state_dict.")
logging.info(f"Deleted {deleted} keys from `model.diffusion_model` state_dict.")

missing, unexpected = (
self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
@@ -1675,18 +1674,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
# megatron_amp_O2 is not yet supported in diffusion models
self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)

if self.cfg.precision in ['16', 16, 'bf16']:
self.model_parallel_config.enable_autocast = False

self.model = self.model_provider_func()

self.conditioning_keys = []

if self.trainer.precision in ['bf16', 'bf16-mixed']:
if self.model.precision in ['bf16', 'bf16-mixed']:
self.autocast_dtype = torch.bfloat16
elif self.trainer.precision in [32, '32', '32-true']:
elif self.model.precision in [32, '32', '32-true']:
self.autocast_dtype = torch.float
elif self.trainer.precision in [16, '16', '16-mixed']:
elif self.model.precision in ['16-mixed', '16', 16]:
self.autocast_dtype = torch.half
else:
raise ValueError('precision must be in ["32-true", "16-mixed", "bf16-mixed"]')
raise ValueError('precision must be in [32, "32", "32-true", "16-mixed", "16", 16, "bf16-mixed", "bf16"]')

self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
self.loss_broadcast_src_rank = None
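
The widened precision check can be read as a lookup from accepted precision values to an autocast dtype. A rough sketch follows, using dtype names as strings so it runs without PyTorch. `autocast_dtype_name` is a hypothetical helper, not part of NeMo.

```python
from typing import Union


def autocast_dtype_name(precision: Union[int, str]) -> str:
    """Map a precision setting to the name of the autocast dtype."""
    if precision in ("bf16", "bf16-mixed"):
        return "bfloat16"
    if precision in (32, "32", "32-true"):
        return "float32"
    if precision in (16, "16", "16-mixed"):
        return "float16"
    # Same accepted set as listed in the updated error message.
    raise ValueError(
        'precision must be in [32, "32", "32-true", "16-mixed", "16", 16, "bf16-mixed", "bf16"]'
    )
```

Accepting both the integer and string spellings matters here because the Jenkins stage passes `trainer.precision=16`, while PyTorch Lightning 2 style configs use values such as `16-mixed`.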
@@ -1780,8 +1782,18 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):

return loss_mean, loss_dict

def training_step(self, dataloader_iter):
def training_step(self, batch):
"""
Notice: `training_step` used to have the following signature to support pipeline
parallelism:
def training_step(self, dataloader_iter, batch_idx):
However, the full-iteration CUDA Graph callback is not currently compatible with this
signature, because the dataloader must be wrapped to generate static tensors outside
the CUDA Graph. This signature moves `next(dataloader)` into the CUDA Graph
capturing region, so it is disabled.
Our dataloaders produce a micro-batch and then we fetch
a number of microbatches depending on the global batch size and model parallel size
from the dataloader to produce a list of microbatches.
@@ -1793,6 +1805,7 @@ def training_step(self, dataloader_iter):
# we zero grads here because we also call backward in the megatron-core fwd/bwd functions
self._optimizer.zero_grad()

dataloader_iter = iter([batch])
loss_mean, loss_dict = self.fwd_bwd_step(dataloader_iter, False)

# when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
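
The `iter([batch])` line adapts the new `training_step(self, batch)` signature to code that still expects an iterator. A minimal sketch of the pattern follows. `fwd_bwd_step` here is a hypothetical stand-in that only pulls one batch, not the NeMo method.

```python
from typing import Any, Iterator


def fwd_bwd_step(dataloader_iter: Iterator[Any]) -> Any:
    """Hypothetical stand-in: consume exactly one batch from the iterator."""
    return next(dataloader_iter)


def training_step(batch: Any) -> Any:
    # The batch is fetched outside this function (and so outside any graph
    # capture region); a one-element iterator keeps the downstream
    # `next(dataloader_iter)` call working unchanged.
    dataloader_iter = iter([batch])
    return fwd_bwd_step(dataloader_iter)
```

A one-element iterator yields exactly once, so this adaptation presumably only suits the case where a single micro-batch is consumed per step, which is consistent with the documented incompatibility with pipeline parallelism.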
@@ -1812,6 +1825,8 @@ def training_step(self, dataloader_iter):
# async grad allreduce is not currently implemented for O1/autocasting mixed precision training
# so we all-reduce gradients after the pipeline
self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf)
else:
raise ValueError("Either distributed_fused_adam or megatron_amp_O2 needs to be set if ddp_overlap is set")

# for cuda graph with pytorch lightning
# these values will be used outside the capturing range
@@ -1828,22 +1843,28 @@ def training_step(self, dataloader_iter):
return loss_mean

def non_cuda_graph_capturable(self):
# Moving CUDA metrics to CPU leads to sync, do not show on progress bar
# if CUDA graph is enabled.
show_metric = self.cfg.get("show_prog_bar_metric", True) and (self.cfg.get("capture_cudagraph_iters", -1) < 0)

if self.log_train_loss:
self.log('reduced_train_loss', self.loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)
self.log('reduced_train_loss', self.loss_mean, prog_bar=show_metric, rank_zero_only=True, batch_size=1)

if self.cfg.precision in [16, '16', '16-mixed']:
loss_scale = self.trainer.precision_plugin.scaler._scale
if loss_scale is not None:
self.log('loss_scale', loss_scale, batch_size=1)

self.log_dict(self.loss_dict, prog_bar=False, logger=True, on_step=True, rank_zero_only=True, batch_size=1)
self.log_dict(
self.loss_dict, prog_bar=show_metric, logger=True, on_step=True, rank_zero_only=True, batch_size=1
)
lr = self._optimizer.param_groups[0]['lr']
self.log('lr', lr, prog_bar=True, rank_zero_only=True, batch_size=1)
self.log('global_step', self.trainer.global_step + 1, prog_bar=True, rank_zero_only=True, batch_size=1)
self.log('lr', lr, prog_bar=show_metric, rank_zero_only=True, batch_size=1)
self.log('global_step', self.trainer.global_step + 1, prog_bar=show_metric, rank_zero_only=True, batch_size=1)
self.log(
'consumed_samples',
self.compute_consumed_samples(self.trainer.global_step + 1 - self.init_global_step),
prog_bar=True,
prog_bar=show_metric,
rank_zero_only=True,
batch_size=1,
)
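
The `show_metric` flag combines two config values: progress-bar metrics are shown only when the user has not disabled them and CUDA Graph capture is off. A small sketch with a plain dict in place of the model config follows. `show_prog_bar_metric` as a function name is hypothetical.

```python
from typing import Any, Mapping


def show_prog_bar_metric(cfg: Mapping[str, Any]) -> bool:
    """Show metrics on the progress bar only when CUDA Graph capture is disabled."""
    user_wants_metrics = bool(cfg.get("show_prog_bar_metric", True))
    # A missing key defaults to -1, which means CUDA Graphs are disabled.
    cuda_graph_disabled = cfg.get("capture_cudagraph_iters", -1) < 0
    return user_wants_metrics and cuda_graph_disabled
```

As the comment in the hunk explains, moving CUDA metrics to the CPU for the progress bar forces a synchronization, so the progress bar is suppressed whenever capture is enabled.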
@@ -1902,7 +1923,7 @@ def process_batch(batch):
return [x, *c_list]

def fwd_output_and_loss_func(dataloader_iter, model):
batch, _, _ = next(dataloader_iter)
batch = next(dataloader_iter)
batch = process_batch(batch)
batch = [x.cuda(non_blocking=True) for x in batch]
if len(self.conditioning_keys) == 0:
@@ -1991,7 +2012,7 @@ def build_train_valid_test_datasets(self):
raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")

if self.cfg.first_stage_key.endswith("encoded") or self.cfg.first_stage_key.endswith("moments"):
if self.cfg.cond_stage_key.endswith("precached_clip"):
if self.cfg.cond_stage_key.endswith("clip_encoded"):
self._train_ds, self._validation_ds = build_train_valid_precached_clip_datasets(
model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0),
)
@@ -2020,7 +2041,7 @@ def setup_training_data(self, cfg):
logging.info(
f'Setting up train dataloader with len(len(self._train_ds)): {len(self._train_ds)} and consumed samples: {consumed_samples}'
)
if self.cfg.cond_stage_key.endswith("precached_clip"):
if self.cfg.cond_stage_key.endswith("clip_encoded"):
collate_fn = get_collate_fn(
first_stage_key=self.cfg.first_stage_key, cond_stage_key=self.cfg.cond_stage_key,
)