Commit d07d76a

Merge branch 'main' into guard_pos_embed_param

timmoon10 committed Apr 9, 2024
2 parents 72f7fd2 + c573826

Showing 16 changed files with 350 additions and 111 deletions.
43 changes: 43 additions & 0 deletions Jenkinsfile
@@ -212,10 +212,53 @@ pipeline {
model.unet_config.use_flash_attention=False \
model.unet_config.attention_resolutions=[1] \
model.unet_config.channel_mult=[1] \
+ model.ddp_overlap=False \
"
sh "rm -rf /home/TestData/multimodal/stable_diffusion_train"
}
}
+ stage('L2: Multimodal Stable Diffusion Train with Cuda Graph') {
+   when {
+     anyOf {
+       branch 'main'
+       changeRequest target: 'main'
+     }
+   }
+   failFast true
+   steps {
+     sh "rm -rf /home/TestData/multimodal/stable_diffusion_train_with_cuda_graphs"
+     sh "python examples/multimodal/text_to_image/stable_diffusion/sd_train.py \
+         trainer.precision=16 \
+         trainer.num_nodes=1 \
+         trainer.devices=1 \
+         ++exp_manager.max_time_per_run=00:00:03:00 \
+         exp_manager.exp_dir=/home/TestData/multimodal/stable_diffusion_train_with_cuda_graphs \
+         trainer.max_steps=20 \
+         model.micro_batch_size=1 \
+         model.global_batch_size=1 \
+         model.data.synthetic_data=True \
+         model.first_stage_key=images_moments \
+         model.cond_stage_key=clip_encoded \
+         model.optim.name=megatron_fused_adam \
+         +model.optim.capturable=True \
+         exp_manager.ema.enable=False \
+         model.cond_stage_config._target_=nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder \
+         ++model.cond_stage_config.version=openai/clip-vit-large-patch14 \
+         ++model.cond_stage_config.max_length=77 \
+         model.inductor=False \
+         ~model.cond_stage_config.restore_from_path \
+         ~model.cond_stage_config.freeze \
+         ~model.cond_stage_config.layer \
+         model.first_stage_config.from_pretrained=null \
+         model.ddp_overlap=False \
+         model.capture_cudagraph_iters=15 \
+         model.unet_config.use_flash_attention=False \
+         model.unet_config.attention_resolutions=[1] \
+         model.unet_config.channel_mult=[1] \
+         "
+     sh "rm -rf /home/TestData/multimodal/stable_diffusion_train_with_cuda_graphs"
+   }
+ }
// stage('L2: Multimodal ControlNet Train') {
// when {
// anyOf {
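The `+model.optim.capturable=True` override in the new stage pairs with `megatron_fused_adam` because full-iteration capture needs the optimizer step itself to be graph-safe. A hedged illustration of the same idea in plain PyTorch, with the standard `torch.optim.Adam` standing in for NeMo's fused optimizer:

    import torch

    # capturable=True keeps the optimizer's step counter and state on the GPU,
    # so opt.step() issues only device-side work and can run inside a CUDA graph.
    params = [torch.nn.Parameter(torch.randn(16, device="cuda"))]
    opt = torch.optim.Adam(params, lr=1e-3, capturable=True)

    params[0].sum().backward()
    opt.step()  # no host-side state mutation, so this is safe to capture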
@@ -119,7 +119,7 @@ model:
use_checkpoint: False
legacy: False
use_flash_attention: True
- enable_amp_o2_fp16: False
+ unet_precision: fp32
resblock_gn_groups: 32
lora_network_alpha: null

@@ -214,4 +214,4 @@ model:
row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers
weight_tying: False
- position_embedding_strategy: null # used only when weight_tying is True
\ No newline at end of file
+ position_embedding_strategy: null # used only when weight_tying is True
@@ -119,7 +119,7 @@ model:
use_checkpoint: False
legacy: False
use_flash_attention: True
- enable_amp_o2_fp16: False
+ unet_precision: fp32
resblock_gn_groups: 32

first_stage_config:
32 changes: 31 additions & 1 deletion examples/multimodal/text_to_image/stable_diffusion/sd_train.py
@@ -23,6 +23,7 @@
from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP
from nemo.core.config import hydra_runner
from nemo.utils import logging
+ from nemo.utils.callbacks import CUDAGraphCallback
from nemo.utils.exp_manager import exp_manager


@@ -56,12 +57,41 @@ def main(cfg) -> None:

torch.backends.cuda.matmul.allow_tf32 = True

- trainer = MegatronStableDiffusionTrainerBuilder(cfg).create_trainer()
+ callbacks = (
+     None
+     if cfg.model.capture_cudagraph_iters < 0
+     else [CUDAGraphCallback(capture_iteration=cfg.model.capture_cudagraph_iters)]
+ )
+ trainer = MegatronStableDiffusionTrainerBuilder(cfg).create_trainer(callbacks)

exp_manager(trainer, cfg.exp_manager)

model = MegatronLatentDiffusion(cfg.model, trainer)

+ if cfg.model.capture_cudagraph_iters >= 0:
+     # Warmup the model with random data
+     with torch.cuda.stream(torch.cuda.Stream()):
+         n, c, h = cfg.model.micro_batch_size, cfg.model.channels, cfg.model.image_size
+         x = torch.randn((n, c, h, h), dtype=torch.float32, device="cuda")
+         t = torch.randint(77, (n,), device="cuda")
+         cc = torch.randn((n, 77, cfg.model.unet_config.context_dim), dtype=torch.float32, device="cuda",)
+         if cfg.model.precision in [16, '16']:
+             x = x.type(torch.float16)
+             cc = cc.type(torch.float16)
+             autocast_enabled = False
+             dgrad_dtype = torch.float16
+         else:
+             autocast_enabled = True
+             dgrad_dtype = torch.float16
+
+         model = model.cuda()
+         for _ in range(5):
+             with torch.autocast(device_type="cuda", enabled=autocast_enabled, dtype=torch.float16):
+                 out = model.model.model.diffusion_model(x, t, context=cc)
+             grad = torch.randn_like(out, dtype=dgrad_dtype)
+             out.backward(grad)
+             model.zero_grad()
+
if cfg.model.get('peft', None):

peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]
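The warmup block above follows the standard CUDA graph recipe: run a few iterations on a side stream so one-time initialization (library handles, autograd buffers, memory-pool growth) happens outside the captured region. A minimal, self-contained sketch of that pattern in plain PyTorch, assuming nothing about NeMo internals (the actual capture in this change is driven by `CUDAGraphCallback`):

    import torch

    model = torch.nn.Linear(64, 64).cuda()
    static_x = torch.randn(8, 64, device="cuda")  # must keep a fixed address across replays

    # 1) Warm up on a side stream, outside any capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(5):
            model(static_x).sum().backward()
            model.zero_grad()
    torch.cuda.current_stream().wait_stream(side)

    # 2) Capture one forward pass into a graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_y = model(static_x)

    # 3) Replay: refresh the static input in place, then replay the recorded kernels.
    static_x.copy_(torch.randn_like(static_x))
    graph.replay()
    print(static_y.sum().item())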
@@ -36,18 +36,22 @@ def __init__(
self.W = image_W
self.image_key = image_key
self.txt_key = txt_key
- assert image_key.endswith('encoded') == txt_key.endswith(
-     'encoded'
- ), 'In precached mode, first and second stage key must both end with "encoded"'
- self.precached = self.image_key.endswith('encoded')
+ img_precached = image_key.endswith('encoded') or image_key.endswith('moments')
+ txt_precached = txt_key.endswith('encoded')
+ assert (
+     img_precached == txt_precached
+ ), 'First and second stage keys should enable/disable precache at the same time.'
self.seq_len = seq_len
self.context_dim = context_dim

def __getitem__(self, index):
item = {}
-         if self.precached:
+         if self.image_key.endswith('encoded'):
              item[self.image_key] = torch.randn(8, self.H // 8, self.W // 8)
              item[self.txt_key] = torch.randn(self.seq_len, self.context_dim)
+         elif self.image_key.endswith('moments'):
+             item[self.image_key] = torch.randn(1, 8, self.H // 8, self.W // 8)
+             item[self.txt_key] = torch.randn(self.seq_len, self.context_dim)
          else:
              item[self.image_key] = torch.randn(self.H, self.W, 3)
              item[self.txt_key] = f'This is meaningless fake text No.{index}'
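For reference, a hedged usage sketch of what each key suffix now yields per item (parameter order inferred from the call sites in this diff; the 512x512 resolution and 768 context dim are illustrative values, not taken from the change):

    # image_key ending in 'moments' -> VAE moments of shape (1, 8, H/8, W/8)
    # image_key ending in 'encoded' -> latents of shape (8, H/8, W/8)
    # any other image_key           -> a raw (H, W, 3) image plus placeholder text
    ds = SDSyntheticDataset(
        512, 512,
        image_key='images_moments',
        txt_key='clip_encoded',
        seq_len=77,
        context_dim=768,
    )
    item = ds[0]
    print(item['images_moments'].shape)  # torch.Size([1, 8, 64, 64])
    print(item['clip_encoded'].shape)    # torch.Size([77, 768])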
@@ -174,7 +178,7 @@ def transform_fn(sample):
if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"):
if data_cfg.get('synthetic_data', False):
H, W = data_cfg.train.augmentations.center_crop_h_w.split(',')
- train_data = SDSyntheticDataset(
+ val_data = SDSyntheticDataset(
int(H),
int(W),
image_key=model_cfg.first_stage_key,
@@ -212,24 +216,46 @@ def transform_fn(sample):
# latents are of shape ([4, 64, 64])
return latents, text_embed

-         train_data = WebDatasetCommon(
-             dataset_cfg=data_cfg,
-             consumed_samples=consumed_samples,
-             map_fn=transform_fn,
-             compose_fn=tuple_to_dict,
-             is_train=True,
-         )
-
-         val_data = None
-         if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"):
-             val_data = WebDatasetCommon(
+         if data_cfg.get('synthetic_data', False):
+             H, W = data_cfg.train.augmentations.center_crop_h_w.split(',')
+             train_data = SDSyntheticDataset(
+                 int(H),
+                 int(W),
+                 image_key=model_cfg.first_stage_key,
+                 txt_key=model_cfg.cond_stage_key,
+                 context_dim=model_cfg.unet_config.context_dim,
+                 seq_len=77,
+             )
+         else:
+             train_data = WebDatasetCommon(
                  dataset_cfg=data_cfg,
                  consumed_samples=consumed_samples,
                  map_fn=transform_fn,
                  compose_fn=tuple_to_dict,
-                 is_train=False,
+                 is_train=True,
              )
+
+         val_data = None
+         if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"):
+             if data_cfg.get('synthetic_data', False):
+                 H, W = data_cfg.train.augmentations.center_crop_h_w.split(',')
+                 val_data = SDSyntheticDataset(
+                     int(H),
+                     int(W),
+                     image_key=model_cfg.first_stage_key,
+                     txt_key=model_cfg.cond_stage_key,
+                     context_dim=model_cfg.unet_config.context_dim,
+                     seq_len=77,
+                 )
+             else:
+                 val_data = WebDatasetCommon(
+                     dataset_cfg=data_cfg,
+                     consumed_samples=consumed_samples,
+                     map_fn=transform_fn,
+                     compose_fn=tuple_to_dict,
+                     is_train=False,
+                 )

return train_data, val_data


@@ -283,7 +283,6 @@ def init_from_ckpt(
for k in keys:
if k.startswith("cond_stage_model"):
deleted += 1
logging.info("Deleting ignored key {} from state_dict.".format(k))
del sd[k]
logging.info(f"Deleted {deleted} keys from `cond_stage_model` state_dict.")

@@ -294,7 +293,7 @@ def init_from_ckpt(
if k.startswith("model.diffusion_model"):
deleted += 1
del sd[k]
logging.info(f"Deleted {deleted} keys from `cond_stage_model` state_dict.")
logging.info(f"Deleted {deleted} keys from `model.diffusion_model` state_dict.")

missing, unexpected = (
self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
@@ -1675,18 +1674,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
# megatron_amp_O2 is not yet supported in diffusion models
self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)

+ if self.cfg.precision in ['16', 16, 'bf16']:
+     self.model_parallel_config.enable_autocast = False
+
self.model = self.model_provider_func()

self.conditioning_keys = []

- if self.trainer.precision in ['bf16', 'bf16-mixed']:
+ if self.model.precision in ['bf16', 'bf16-mixed']:
      self.autocast_dtype = torch.bfloat16
- elif self.trainer.precision in [32, '32', '32-true']:
+ elif self.model.precision in [32, '32', '32-true']:
      self.autocast_dtype = torch.float
- elif self.trainer.precision in [16, '16', '16-mixed']:
+ elif self.model.precision in ['16-mixed', '16', 16]:
      self.autocast_dtype = torch.half
  else:
-     raise ValueError('precision must be in ["32-true", "16-mixed", "bf16-mixed"]')
+     raise ValueError('precision must be in [32, "32", "32-true", "16-mixed", "16", 16, "bf16-mixed", "bf16"]')

self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
self.loss_broadcast_src_rank = None
@@ -1780,8 +1782,18 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):

return loss_mean, loss_dict

- def training_step(self, dataloader_iter):
+ def training_step(self, batch):
      """
+     Notice: `training_step` used to have the following signature to support pipeline
+     parallelism:
+
+         def training_step(self, dataloader_iter, batch_idx):
+
+     However, the full-iteration CUDA Graph callback is not compatible with that signature,
+     because the dataloader must be wrapped so that static tensors are generated outside
+     the CUDA Graph. The old signature moved `next(dataloader_iter)` into the CUDA Graph
+     capturing region, so it has been disabled.
+
      Our dataloaders produce a micro-batch and then we fetch
      a number of microbatches depending on the global batch size and model parallel size
      from the dataloader to produce a list of microbatches.
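In short, the trainer loop now fetches the batch outside the captured region, and `training_step` re-wraps it so Megatron's fwd/bwd helpers still receive an iterator. A condensed sketch of the resulting control flow (simplified from the hunks below, not the complete method):

    def training_step(self, batch):
        self._optimizer.zero_grad()
        # Wrap the already-fetched micro-batch so fwd_bwd_step can call next() on it;
        # next() on the real dataloader never runs inside the capturing region.
        dataloader_iter = iter([batch])
        loss_mean, loss_dict = self.fwd_bwd_step(dataloader_iter, False)
        return loss_mean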
@@ -1793,6 +1805,7 @@ def training_step(self, dataloader_iter):
# we zero grads here because we also call backward in the megatron-core fwd/bwd functions
self._optimizer.zero_grad()

+ dataloader_iter = iter([batch])
loss_mean, loss_dict = self.fwd_bwd_step(dataloader_iter, False)

# when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
@@ -1812,6 +1825,8 @@ def training_step(self, dataloader_iter):
# async grad allreduce is not currently implemented for O1/autocasting mixed precision training
# so we all-reduce gradients after the pipeline
self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf)
+ else:
+     raise ValueError("Either distributed_fused_adam or megatron_amp_O2 needs to be set if ddp_overlap is set")

# for cuda graph with pytorch lightning
# these values will be used outside the capturing range
@@ -1828,22 +1843,28 @@ def training_step(self, dataloader_iter):
return loss_mean

def non_cuda_graph_capturable(self):
+ # Moving CUDA metrics to the CPU forces a sync, so do not show them on the
+ # progress bar when CUDA graphs are enabled.
+ show_metric = self.cfg.get("show_prog_bar_metric", True) and (self.cfg.get("capture_cudagraph_iters", -1) < 0)
+
if self.log_train_loss:
- self.log('reduced_train_loss', self.loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)
+ self.log('reduced_train_loss', self.loss_mean, prog_bar=show_metric, rank_zero_only=True, batch_size=1)

if self.cfg.precision in [16, '16', '16-mixed']:
loss_scale = self.trainer.precision_plugin.scaler._scale
if loss_scale is not None:
self.log('loss_scale', loss_scale, batch_size=1)

- self.log_dict(self.loss_dict, prog_bar=False, logger=True, on_step=True, rank_zero_only=True, batch_size=1)
+ self.log_dict(
+     self.loss_dict, prog_bar=show_metric, logger=True, on_step=True, rank_zero_only=True, batch_size=1
+ )
lr = self._optimizer.param_groups[0]['lr']
- self.log('lr', lr, prog_bar=True, rank_zero_only=True, batch_size=1)
- self.log('global_step', self.trainer.global_step + 1, prog_bar=True, rank_zero_only=True, batch_size=1)
+ self.log('lr', lr, prog_bar=show_metric, rank_zero_only=True, batch_size=1)
+ self.log('global_step', self.trainer.global_step + 1, prog_bar=show_metric, rank_zero_only=True, batch_size=1)
self.log(
'consumed_samples',
self.compute_consumed_samples(self.trainer.global_step + 1 - self.init_global_step),
-     prog_bar=True,
+     prog_bar=show_metric,
rank_zero_only=True,
batch_size=1,
)
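The sync that the new comment refers to is the usual one: any host-side read of a CUDA tensor, which progress-bar logging ultimately performs, blocks the CPU until pending kernels finish, and that stalls full-iteration graph replay. A two-line illustration of the effect:

    import torch

    loss = torch.randn((), device="cuda")
    value = loss.item()  # implicit device synchronization: the host waits for the GPU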
@@ -1902,7 +1923,7 @@ def process_batch(batch):
return [x, *c_list]

def fwd_output_and_loss_func(dataloader_iter, model):
- batch, _, _ = next(dataloader_iter)
+ batch = next(dataloader_iter)
batch = process_batch(batch)
batch = [x.cuda(non_blocking=True) for x in batch]
if len(self.conditioning_keys) == 0:
@@ -1991,7 +2012,7 @@ def build_train_valid_test_datasets(self):
raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")

if self.cfg.first_stage_key.endswith("encoded") or self.cfg.first_stage_key.endswith("moments"):
if self.cfg.cond_stage_key.endswith("precached_clip"):
if self.cfg.cond_stage_key.endswith("clip_encoded"):
self._train_ds, self._validation_ds = build_train_valid_precached_clip_datasets(
model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0),
)
@@ -2020,7 +2041,7 @@ def setup_training_data(self, cfg):
logging.info(
f'Setting up train dataloader with len(len(self._train_ds)): {len(self._train_ds)} and consumed samples: {consumed_samples}'
)
if self.cfg.cond_stage_key.endswith("precached_clip"):
if self.cfg.cond_stage_key.endswith("clip_encoded"):
collate_fn = get_collate_fn(
first_stage_key=self.cfg.first_stage_key, cond_stage_key=self.cfg.cond_stage_key,
)