diff --git a/.gitignore b/.gitignore index 1ff2a92cac64..1aa5ef00de5e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ *.pkl #*.ipynb output +output_2048 result *.pt tests/data/asr @@ -179,3 +180,4 @@ examples/neural_graphs/*.yml .hydra/ nemo_experiments/ +slurm*.out diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml index dff963590864..da03a1de96cf 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml @@ -17,7 +17,6 @@ trainer: enable_model_summary: True limit_val_batches: 0 - exp_manager: exp_dir: null name: ${name} diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml index c536bae15926..7e83093eb780 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml @@ -58,8 +58,6 @@ model: lossconfig: target: torch.nn.Identity - - conditioner_config: _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner emb_models: diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml index 7aa765db2e5f..aa1d2782d15b 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml @@ -125,7 +125,6 @@ model: target: torch.nn.Identity - conditioner_config: _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner emb_models: diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml index eb1f6d7ccb8e..632f1634af50 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml @@ -31,9 +31,9 @@ infer: sampling: base: sampler: EulerEDMSampler - width: 256 - height: 256 - steps: 40 + width: 512 + height: 512 + steps: 50 discretization: "LegacyDDPMDiscretization" guider: "VanillaCFG" thresholder: "None" @@ -48,8 +48,8 @@ sampling: s_noise: 1.0 eta: 1.0 order: 4 - orig_width: 1024 - orig_height: 1024 + orig_width: 512 + orig_height: 512 crop_coords_top: 0 crop_coords_left: 0 aesthetic_score: 5.0 diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer_v2.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer_v2.yaml new file mode 100644 index 000000000000..9dc838dcc5c5 --- /dev/null +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer_v2.yaml @@ -0,0 +1,189 @@ +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 32 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. 
+ max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + limit_val_batches: 0 + + +infer: + num_samples_per_batch: 1 + num_samples: 4 + prompt: + - "A professional photograph of an astronaut riding a pig" + - 'A photo of a Shiba Inu dog with a backpack riding a bike. It is wearing sunglasses and a beach hat.' + - 'A cute corgi lives in a house made out of sushi.' + - 'A high contrast portrait of a very happy fuzzy panda dressed as a chef in a high end kitchen making dough. There is a painting of flowers on the wall behind him.' + - 'A brain riding a rocketship heading towards the moon.' + negative_prompt: "" + seed: 123 + + +sampling: + base: + sampler: EulerEDMSampler + width: 512 + height: 512 + steps: 50 + discretization: "LegacyDDPMDiscretization" + guider: "VanillaCFG" + thresholder: "None" + scale: 5.0 + img2img_strength: 1.0 + sigma_min: 0.0292 + sigma_max: 14.6146 + rho: 3.0 + s_churn: 0.0 + s_tmin: 0.0 + s_tmax: 999.0 + s_noise: 1.0 + eta: 1.0 + order: 4 + orig_width: 512 + orig_height: 512 + crop_coords_top: 0 + crop_coords_left: 0 + aesthetic_score: 5.0 + negative_aesthetic_score: 5.0 + +# model: +# is_legacy: False + +use_refiner: False +use_fp16: False # use fp16 model weights +out_path: ./output + +base_model_config: /opt/NeMo/examples/multimodal/generative/stable_diffusion/conf/sd_xl_base.yaml +refiner_config: /opt/NeMo/examples/multimodal/generative/stable_diffusion/conf/sd_xl_refiner.yaml + +model: + scale_factor: 0.13025 + disable_first_stage_autocast: True + is_legacy: False + restore_from_path: "" + + fsdp: False + fsdp_set_buffer_dtype: null + fsdp_sharding_strategy: 'full' + use_cpu_initialization: True + # hidden_size: 4 + # pipeline_model_parallel_size: 4 + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.0 + betas: + - 0.9 + - 0.999 + sched: + name: WarmupHoldPolicy + warmup_steps: 10 + hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant + + denoiser_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser.DiscreteDenoiser + num_idx: 1000 + + weighting_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.discretizer.LegacyDDPMDiscretization + + unet_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel + from_pretrained: /opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt + from_NeMo: True + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: False + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4 ] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [ 1, 2, 10 ] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + image_size: 64 # unused +# spatial_transformer_attn_type: softmax #note: only default softmax is supported now + legacy: False + use_flash_attention: False + + 
first_stage_config: + # _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper + _target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper + from_pretrained: /opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt + from_NeMo: True + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 4, 4 ] + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + conditioner_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner + emb_models: + # crossattn cond + - is_trainable: False + input_key: txt + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder + layer: hidden + layer_idx: 11 + # crossattn and vector cond + - is_trainable: False + input_key: txt + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenOpenCLIPEmbedder2 + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: target_size_as_tuple + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND + outdim: 256 # multiplied by two + diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py index 968d9bec2884..7e151699b38c 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py @@ -74,7 +74,11 @@ def main(cfg) -> None: n, c, h = cfg.model.micro_batch_size, cfg.model.channels, cfg.model.image_size x = torch.randn((n, c, h, h), dtype=torch.float32, device="cuda") t = torch.randint(77, (n,), device="cuda") - cc = torch.randn((n, 77, cfg.model.unet_config.context_dim), dtype=torch.float32, device="cuda",) + cc = torch.randn( + (n, 77, cfg.model.unet_config.context_dim), + dtype=torch.float32, + device="cuda", + ) if cfg.model.precision in [16, '16']: x = x.type(torch.float16) cc = cc.type(torch.float16) @@ -93,9 +97,7 @@ def main(cfg) -> None: model.zero_grad() if cfg.model.get('peft', None): - peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] - if cfg.model.peft.restore_from_path is not None: # initialize peft weights from a checkpoint instead of randomly # This is not the same as resume training because optimizer states are not restored. 
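A note on the conditioner block above: the five embedders are what pin down the UNet's `adm_in_channels: 2816` in the `unet_config`. A quick sanity check (a sketch, assuming the standard 1280-dim pooled text embedding of OpenCLIP ViT-bigG-14; not part of the patch):

    # Each ConcatTimestepEmbedderND has outdim=256 and embeds a 2-tuple,
    # hence the "multiplied by two" comments in the config.
    pooled_text = 1280           # FrozenOpenCLIPEmbedder2 with always_return_pooled=True
    per_size_tuple = 256 * 2     # original_size_as_tuple, crop_coords_top_left, target_size_as_tuple
    assert pooled_text + 3 * per_size_tuple == 2816  # = unet_config.adm_in_channels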
diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py index 8d18be517c69..981e83ec95c4 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py @@ -26,32 +26,44 @@ def model_cfg_modifier(model_cfg): model_cfg.precision = cfg.trainer.precision model_cfg.ckpt_path = None model_cfg.inductor = False - model_cfg.unet_config.from_pretrained = None - model_cfg.first_stage_config.from_pretrained = None + model_cfg.unet_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt" + model_cfg.unet_config.from_NeMo = True + model_cfg.first_stage_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt" + model_cfg.first_stage_config.from_NeMo = True model_cfg.first_stage_config._target_ = 'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper' - model_cfg.fsdp = False + # model_cfg.fsdp = True torch.backends.cuda.matmul.allow_tf32 = True trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( model_provider=MegatronDiffusionEngine, cfg=cfg, model_cfg_modifier=model_cfg_modifier ) + ### Manually configure sharded model + # model = megatron_diffusion_model + # model = trainer.strategy._setup_model(model) + # model = model.cuda(torch.cuda.current_device()) + # get the diffusion part only model = megatron_diffusion_model.model model.cuda().eval() - base = SamplingPipeline(model, use_fp16=cfg.use_fp16, is_legacy=cfg.model.is_legacy) - use_refiner = cfg.get('use_refiner', False) - for i, prompt in enumerate(cfg.infer.prompt): - samples = base.text_to_image( - params=cfg.sampling.base, - prompt=[prompt], - negative_prompt=cfg.infer.negative_prompt, - samples=cfg.infer.num_samples, - return_latents=True if use_refiner else False, - seed=int(cfg.infer.seed + i * 100), - ) - - perform_save_locally(cfg.out_path, samples) + with torch.no_grad(): + base = SamplingPipeline(model, use_fp16=cfg.use_fp16, is_legacy=cfg.model.is_legacy) + use_refiner = cfg.get('use_refiner', False) + num_samples_per_batch = cfg.infer.get('num_samples_per_batch', cfg.infer.num_samples) + num_batches = cfg.infer.num_samples // num_samples_per_batch + + for i, prompt in enumerate(cfg.infer.prompt): + for batchid in range(num_batches): + samples = base.text_to_image( + params=cfg.sampling.base, + prompt=[prompt], + negative_prompt=cfg.infer.negative_prompt, + samples=num_samples_per_batch, + return_latents=True if use_refiner else False, + seed=int(cfg.infer.seed + i * 100 + batchid * 200), + ) + # samples=cfg.infer.num_samples, + perform_save_locally(cfg.out_path, samples) if __name__ == "__main__": diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py index a91beca93761..44412aee0d14 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py @@ -41,7 +41,10 @@ def _training_strategy(self) -> NLPDDPStrategy: _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) if _IS_INTERACTIVE and self.cfg.trainer.devices == 1: logging.info("Detected interactive environment, using NLPDDPStrategyNotebook") - return NLPDDPStrategyNotebook(no_ddp_communication_hook=True, find_unused_parameters=False,) + return NLPDDPStrategyNotebook( + 
no_ddp_communication_hook=True, + find_unused_parameters=False, + ) if self.cfg.model.get('fsdp', False): assert ( @@ -81,9 +84,7 @@ def main(cfg) -> None: model = MegatronDiffusionEngine(cfg.model, trainer) if cfg.model.get('peft', None): - peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] - if cfg.model.peft.restore_from_path is not None: # initialize peft weights from a checkpoint instead of randomly # This is not the same as resume training because optimizer states are not restored. diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py index efc1550113a0..755588202ef0 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py @@ -119,7 +119,9 @@ def __init__(self, cfg, model_parallel_config): self._init_first_stage(first_stage_config) self.model_type = None - self.rng = torch.Generator(device=torch.cuda.current_device(),) + self.rng = torch.Generator( + device=torch.cuda.current_device(), + ) self.use_ema = False # TODO use_ema need to switch to NeMo style if self.use_ema: @@ -158,6 +160,13 @@ def decode_first_stage(self, z): out = self.first_stage_model.decode(z) return out + # same as above but differentiable + def differentiable_decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + out = self.first_stage_model.decode(z) + return out + @torch.no_grad() def encode_first_stage(self, x): with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): @@ -185,7 +194,12 @@ def training_step(self, batch, batch_idx): self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False) self.log( - "global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, + "global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, ) if self.scheduler_config is not None: @@ -231,7 +245,11 @@ def configure_optimizers(self): scheduler = DiffusionEngine.from_config_dict(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ - {"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1,} + { + "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } ] return [opt], scheduler return opt @@ -291,7 +309,14 @@ def set_input_tensor(self, input_tensor): pass @torch.no_grad() - def log_images(self, batch: Dict, N: int = 8, sample: bool = True, ucg_keys: List[str] = None, **kwargs,) -> Dict: + def log_images( + self, + batch: Dict, + N: int = 8, + sample: bool = True, + ucg_keys: List[str] = None, + **kwargs, + ) -> Dict: conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] if ucg_keys: assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( @@ -305,7 +330,8 @@ def log_images(self, batch: Dict, N: int = 8, sample: bool = True, ucg_keys: Lis x = self.get_input(batch) c, uc = self.conditioner.get_unconditional_conditioning( - batch, force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], + batch, + force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], ) sampling_kwargs = {} @@ -400,7 +426,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): # handle asynchronous grad 
reduction
         no_sync_func = None
         if not forward_only and self.with_distributed_adam:
-            no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,)
+            no_sync_func = partial(
+                self._optimizer.no_sync,
+                greedy_grad_copy=self.megatron_amp_O2,
+            )

         # pipeline schedules will get these from self.model.config
         for module in self.get_module_list():
@@ -438,12 +467,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):

     def training_step(self, dataloader_iter):
         """
-            Our dataloaders produce a micro-batch and then we fetch
-            a number of microbatches depending on the global batch size and model parallel size
-            from the dataloader to produce a list of microbatches.
-            Batch should be a list of microbatches and those microbatches should on CPU.
-            Microbatches are then moved to GPU during the pipeline.
-            The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
+        Our dataloaders produce a micro-batch and then we fetch
+        a number of microbatches depending on the global batch size and model parallel size
+        from the dataloader to produce a list of microbatches.
+        Batch should be a list of microbatches, and those microbatches should be on CPU.
+        Microbatches are then moved to GPU during the pipeline.
+        The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
         """
         self._optimizer.zero_grad()
@@ -491,20 +520,20 @@ def training_step(self, dataloader_iter):
         return loss_mean

     def backward(self, *args, **kwargs):
-        """ LightningModule hook to do backward.
-            We want this to do nothing since we run backward in the fwd/bwd functions from apex.
-            No need to call it here.
+        """LightningModule hook to do backward.
+        We want this to do nothing since we run backward in the fwd/bwd functions from apex.
+        No need to call it here.
         """
         pass

     def optimizer_zero_grad(self, *args, **kwargs):
-        """ LightningModule hook to zero grad.
-            We want this to do nothing as we are zeroing grads during the training_step.
+        """LightningModule hook to zero grad.
+        We want this to do nothing as we are zeroing grads during the training_step.
         """
         pass

     def _append_sequence_parallel_module_grads(self, module, grads):
-        """ Helper method for allreduce_sequence_parallel_gradients"""
+        """Helper method for allreduce_sequence_parallel_gradients"""

         for param in module.parameters():
             sequence_parallel_param = getattr(param, 'sequence_parallel', False)
@@ -517,12 +546,13 @@ def get_forward_output_and_loss_func(self):
         def process_batch(batch):
-            """ Prepares the global batch for apex fwd/bwd functions.
-                Global batch is a list of micro batches.
+            """Prepares the global batch for apex fwd/bwd functions.
+            Global batch is a list of micro batches.
             """
             # SD has more dedicated structure for encoding, so we enable autocasting here as well
             with torch.cuda.amp.autocast(
-                self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype,
+                self.autocast_dtype in (torch.half, torch.bfloat16),
+                dtype=self.autocast_dtype,
             ):
                 if self.model.precache_mode == 'both':
                     x = batch[self.model.input_key].to(torch.cuda.current_device())
@@ -565,7 +595,7 @@ def validation_step(self, dataloader_iter, batch_idx):
         return loss

     def setup(self, stage=None):
-        """ PTL hook that is executed after DDP spawns.
+        """PTL hook that is executed after DDP spawns.
            We setup datasets here as megatron datasets require DDP to instantiate.
See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. Args: @@ -678,20 +708,23 @@ def setup_test_data(self, cfg): f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' ) self._test_dl = torch.utils.data.DataLoader( - self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + self._test_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, ) def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. """ if self.trainer.accumulate_grad_batches > 1: raise ValueError( diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py index 6bd47a78fbcf..d79d85c2e026 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py @@ -16,6 +16,7 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F +from nemo.utils import logging try: from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer @@ -316,6 +317,7 @@ def __init__( ignore_keys=[], image_key="image", colorize_nlabels=None, + from_NeMo=False, monitor=None, from_pretrained: str = None, ): @@ -337,6 +339,7 @@ def __init__( self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) if from_pretrained is not None: + logging.info(f"Attempting to load vae weights from {from_pretrained}") if from_pretrained.endswith('safetensors'): from safetensors.torch import load_file as load_safetensors @@ -345,7 +348,7 @@ def __init__( state_dict = torch.load(from_pretrained) if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] - missing_key, unexpected_key, _, _ = self._load_pretrained_model(state_dict) + missing_key, unexpected_key, _, _ = self._load_pretrained_model(state_dict, from_NeMo=from_NeMo) if len(missing_key) > 0: print( f'{self.__class__.__name__}: Following keys are missing during loading VAE weights, which may lead to compromised image quality for a resumed training. Please check the checkpoint you provided.' 
@@ -395,8 +398,9 @@ def _state_key_mapping(self, state_dict: dict): res_dict[key_] = val_ return res_dict - def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): - state_dict = self._state_key_mapping(state_dict) + def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False): + if not from_NeMo: + state_dict = self._state_key_mapping(state_dict) model_state_dict = self.state_dict() loaded_keys = [k for k in state_dict.keys()] expected_keys = list(model_state_dict.keys()) @@ -405,7 +409,10 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): unexpected_keys = list(set(loaded_keys) - set(expected_keys)) def _find_mismatched_keys( - state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, ): mismatched_keys = [] if ignore_mismatched_sizes: @@ -440,7 +447,10 @@ def _find_mismatched_keys( if state_dict is not None: # Whole checkpoint mismatched_keys = _find_mismatched_keys( - state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, ) error_msgs = self._load_state_dict_into_model(state_dict) return missing_keys, unexpected_keys, mismatched_keys, error_msgs diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py index 2eeed97db781..e748bcbf93a0 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/attention.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py @@ -227,6 +227,10 @@ def __init__(self, in_features, out_features, bias=True, lora_network_alpha=None def forward(self, x): mixed_x = super().forward(x) if self.is_adapter_available(): + # return this output if lora is not enabled + cfg = self.get_adapter_cfg(AdapterName.PARALLEL_LINEAR_ADAPTER) + if not cfg['enabled']: + return mixed_x lora_linear_adapter = self.get_adapter_module(AdapterName.PARALLEL_LINEAR_ADAPTER) lora_mixed_x = lora_linear_adapter(x) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
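The attention.py hunk above adds an early exit so that an attached-but-disabled LoRA adapter no longer contributes to the output; it relies on the new `get_adapter_cfg` helper introduced in `adapter_mixins.py` further down in this patch. A minimal sketch of the resulting forward logic (the adapter method names mirror the patch, but the config keys `enabled`/`dim` are illustrative assumptions, not the exact NeMo schema):

    def forward(self, x):
        mixed_x = super().forward(x)                      # frozen base projection
        if not self.is_adapter_available():
            return mixed_x
        cfg = self.get_adapter_cfg(AdapterName.PARALLEL_LINEAR_ADAPTER)
        if not cfg['enabled']:
            return mixed_x                                # the new early exit
        lora = self.get_adapter_module(AdapterName.PARALLEL_LINEAR_ADAPTER)
        # network_alpha rescales the low-rank update, kohya-ss style
        return mixed_x + lora(x) * (self.lora_network_alpha / cfg['dim'])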
diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py index df1f27449bd1..a358bb08f92d 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py @@ -33,13 +33,18 @@ def possibly_quantize_c_noise(self, c_noise): def w(self, sigma): return self.weighting(sigma) - def __call__(self, network, input, sigma, cond): + def __call__(self, network, input, sigma, cond, return_noise=False): sigma = self.possibly_quantize_sigma(sigma) sigma_shape = sigma.shape sigma = append_dims(sigma, input.ndim) c_skip, c_out, c_in, c_noise = self.scaling(sigma) c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) - return network(input * c_in, c_noise, cond) * c_out + input * c_skip + # predict noise from network + noise_pred = network(input * c_in, c_noise, cond) + denoised = noise_pred * c_out + input * c_skip + if return_noise: + return denoised, noise_pred + return denoised class DiscreteDenoiser(Denoiser): diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 7f8b2fb20bff..eb449c5406b9 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -789,6 +789,7 @@ def __init__( self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) + if level != len(channel_mult) - 1: out_ch = ch self.input_blocks.append( @@ -954,6 +955,7 @@ def __init__( ) if from_pretrained is not None: + logging.info(f"Attempting to load pretrained unet from {from_pretrained}") if from_pretrained.endswith('safetensors'): from safetensors.torch import load_file as load_safetensors @@ -1021,6 +1023,16 @@ def _input_blocks_mapping(self, input_dict): .replace('conv2', 'out_layers.3') .replace('conv_shortcut', 'skip_connection') ) + ## Rohit: I've changed this to make sure it is compatible + # post_fix = ( + # key_[25:] + # .replace('time_emb_proj', 'emb_layers.1') + # .replace('norm1', 'in_layers.0') + # .replace('norm2', 'out_layers.0') + # .replace('conv1', 'in_layers.1') + # .replace('conv2', 'out_layers.2') + # .replace('conv_shortcut', 'skip_connection') + # ) res_dict["input_blocks." + str(target_id) + '.0.' 
+ post_fix] = value_ elif "attentions" in key_: id_1 = int(key_[26]) @@ -1168,7 +1180,7 @@ def te_fp8_key_mapping(self, unet_dict): return new_state_dict def _state_key_mapping(self, state_dict: dict): - + # state_dict is a HF model res_dict = {} input_dict = {} mid_dict = {} diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py index c636ffec345d..bfae8790eeb2 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py @@ -47,7 +47,12 @@ def __init__( ): self.num_steps = num_steps self.discretization = instantiate_from_config(discretization_config) - self.guider = instantiate_from_config(default(guider_config, DEFAULT_GUIDER,)) + self.guider = instantiate_from_config( + default( + guider_config, + DEFAULT_GUIDER, + ) + ) self.verbose = verbose self.device = device @@ -93,35 +98,50 @@ def euler_step(self, x, d, dt): class EDMSampler(SingleStepDiffusionSampler): def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): super().__init__(*args, **kwargs) - self.s_churn = s_churn self.s_tmin = s_tmin self.s_tmax = s_tmax self.s_noise = s_noise - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0, return_noise=False): + # x is actually \bar{x} as in the DDIM paper sigma_hat = sigma * (gamma + 1.0) if gamma > 0: eps = torch.randn_like(x) * self.s_noise - x = x + eps * append_dims(sigma_hat ** 2 - sigma ** 2, x.ndim) ** 0.5 + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) + # this is the noise (e_t) d = to_d(x, sigma_hat, denoised) dt = append_dims(next_sigma - sigma_hat, x.ndim) - euler_step = self.euler_step(x, d, dt) + euler_step = self.euler_step(x, d, dt) # this is x_{t-\delta{t}} x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + if return_noise: + return x, d return x + def get_gamma(self, sigmas, num_sigmas, index): + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[index] <= self.s_tmax else 0.0 + ) + return gamma + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + # prepare_sampling_loop converts x into \bar{x} = x / \sqrt{\tilde{\alpha_t}} x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): - gamma = ( - min(self.s_churn / (num_sigmas - 1), 2 ** 0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + gamma = self.get_gamma(sigmas, num_sigmas, i) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, ) - x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc, gamma,) - return x @@ -151,14 +171,24 @@ def __call__(self, denoiser, x, cond, uc=None, num_steps=None): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): - x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc,) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + ) return x class LinearMultistepSampler(BaseDiffusionSampler): def __init__( - self, 
order=4, *args, **kwargs, + self, + order=4, + *args, + **kwargs, ): super().__init__(*args, **kwargs) @@ -276,7 +306,15 @@ def get_mult(self, h, r, t, t_next, previous_sigma): return mult1, mult2 def sampler_step( - self, old_denoised, previous_sigma, sigma, next_sigma, denoiser, x, cond, uc=None, + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, ): denoised = self.denoise(x, denoiser, sigma, cond, uc) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py index 0d465c1275c6..24e2124e6f83 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py @@ -37,6 +37,11 @@ class OpenAIWrapper(IdentityWrapper): def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor: if c.get("concat", None): x = torch.cat((x, c.get("concat")), dim=1) + return self.diffusion_model( - x, timesteps=t, context=c.get("crossattn", None), y=c.get("vector", None), **kwargs, + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + **kwargs, ) diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index 7eb72b38d0f0..5a01e8702a9e 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -23,11 +23,11 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import TorchElasticEnvironment from transformers import CLIPImageProcessor, SiglipImageProcessor -from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform +from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform from nemo.collections.multimodal.data.neva.neva_dataset import process_image from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel -from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPFSDPStrategy, NLPSaveRestoreConnector from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import AppState, logging @@ -276,10 +276,23 @@ def setup_trainer_and_model_for_inference( # Use the NLPDDPStrategy for the distributed data parallel strategy. # We don't use DDP for async grad allreduce and don't find unused parameters. - strategy = NLPDDPStrategy( - no_ddp_communication_hook=True, - find_unused_parameters=False, - ) + if not cfg.model.get('fsdp', False): + logging.info("FSDP is False, using DDP strategy.") + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) + else: + logging.info("Using FSDP strategy.") + strategy = NLPFSDPStrategy( + limit_all_gathers=cfg.model.get('fsdp_limit_all_gathers', True), + sharding_strategy=cfg.model.get('fsdp_sharding_strategy', 'full'), + cpu_offload=cfg.model.get('fsdp_cpu_offload', True), + grad_reduce_dtype=cfg.model.get('fsdp_grad_reduce_dtype', 32), + precision=cfg.trainer.precision, + # use_orig_params=cfg.model.inductor, + set_buffer_dtype=cfg.get('fsdp_set_buffer_dtype', None), + ) # Set up the trainer with the specified plugins and strategy. 
trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) @@ -323,7 +336,9 @@ def setup_trainer_and_model_for_inference( ) else: - raise ValueError(f"Unrecognized checkpoint type: {cfg.model.restore_from_path}") + # load a model from scratch + logging.warning("Loading a model from scratch for inference. Tread carefully.") + model = model_provider(cfg=cfg.model, trainer=trainer) # initialize apex DDP strategy def dummy(): diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 4ded9a42db4f..e1641a81c0dc 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -1271,6 +1271,8 @@ def find_frozen_submodules(model): # TODO: Currently the main parameter data type is kept in fp32 (when O2=False). This needs to be # extended to support lower precision main parameters. frozen_submodule_names, frozen_submodules = find_frozen_submodules(self.model) + for submodule in frozen_submodule_names: + logging.debug(f"Ignoring state {submodule} in FSDP.") self.trainer.strategy.kwargs['ignored_states'] = frozen_submodules # FSDP requires uniform status of require_grads # Diffusion models like SD has frozen parts and needs to be added to 'ignored_states' from sharding for FSDP to work diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py index 45f4af3cfbf3..2bacaf52e3f8 100644 --- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py +++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py @@ -161,7 +161,6 @@ def _get_layers_from_model(self, model): def _check_and_add_peft_cfg(self, peft_cfg): layer_selection = peft_cfg.layer_selection - assert not self.use_mcore_gpt or hasattr( peft_cfg, 'name_key_to_mcore_mixins' ), f"{peft_cfg.__class__.__name__} is not supported in megatron core mode yet." 
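For context on the `ignored_states` change in megatron_base_model.py above: FSDP requires uniform `requires_grad` within a wrapped unit, so fully frozen submodules (e.g. the CLIP conditioner and VAE of Stable Diffusion) are excluded from sharding. Roughly, the idea behind `find_frozen_submodules` (a sketch, not NeMo's exact implementation):

    def find_frozen_submodules(model):
        # collect submodules whose parameters are all frozen, so FSDP can ignore them
        frozen_names, frozen_modules = [], []
        for name, module in model.named_modules():
            params = list(module.parameters(recurse=True))
            if params and all(not p.requires_grad for p in params):
                frozen_names.append(name)
                frozen_modules.append(module)
        return frozen_names, frozen_modules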
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index e251690831cb..b003e310baeb 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -701,6 +701,7 @@ def __init__(
         nccl_communicator_config_path: Optional[str] = None,
         sharp: bool = False,
         set_buffer_dtype: Optional[str] = None,
+        extra_fsdp_wrap_module: Optional[set] = None,
         **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
         if not HAVE_APEX:
@@ -730,6 +731,11 @@ def __init__(
             ParallelTransformerLayer,
             BasicTransformerBlock,
         }
+
+        # if extra wrap modules are provided, use them
+        if extra_fsdp_wrap_module is not None:
+            self.fsdp_wrap_module.update(extra_fsdp_wrap_module)
+
         kwargs['auto_wrap_policy'] = functools.partial(
             transformer_auto_wrap_policy, transformer_layer_cls=self.fsdp_wrap_module
         )
diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py
index 05ac9b429d85..7b5d02c86bf7 100644
--- a/nemo/core/classes/mixins/adapter_mixins.py
+++ b/nemo/core/classes/mixins/adapter_mixins.py
@@ -391,6 +391,14 @@ def get_adapter_module(self, name: str):
             return self.adapter_layer[name] if name in self.adapter_layer else None
         return None

+    def get_adapter_cfg(self, name: str):
+        """Same logic as `get_adapter_module` but to get the config"""
+        _, name = self.resolve_adapter_module_name_(name)
+
+        if hasattr(self, "adapter_cfg"):
+            return self.adapter_cfg[name] if name in self.adapter_cfg else None
+        return None
+
     def set_accepted_adapter_types(self, adapter_types: List[Union[type, str]]) -> None:
         """
         The module with this mixin can define a list of adapter names that it will accept.
diff --git a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py
new file mode 100644
index 000000000000..67bc975708d0
--- /dev/null
+++ b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py
@@ -0,0 +1,452 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""
+Conversion script to convert HuggingFace Stable Diffusion (UNet and VAE) checkpoints into NeMo checkpoints.
+  Example to run this conversion script:
+    python convert_stablediffusion_hf_to_nemo.py \
+     --input_name_or_path <path to the HF checkpoint folder> \
+     --model <unet or vae> \
+     --output_path <path to the output checkpoint file>
+"""
+
+from argparse import ArgumentParser
+
+import numpy as np
+import safetensors
+import torch
+import torch.nn
+
+from nemo.utils import logging
+
+intkey = lambda x: int(x)
+
+
+def filter_keys(rule, dict):
+    keys = list(dict.keys())
+    nd = {k: dict[k] for k in keys if rule(k)}
+    return nd
+
+
+def map_keys(rule, dict):
+    new = {rule(k): v for k, v in dict.items()}
+    return new
+
+
+def split_name(name, dots=0):
+    l = name.split(".")
+    return ".".join(l[: dots + 1]), ".".join(l[dots + 1 :])
+
+
+def is_prefix(shortstr, longstr):
+    # is the first string a prefix of the second one
+    return longstr == shortstr or longstr.startswith(shortstr + ".")
+
+
+def numdots(str):
+    return str.count(".")
+
+
+class SegTree:
+    def __init__(self):
+        self.nodes = dict()
+        self.val = None
+        self.final_val = 0
+        self.convert_name = None
+
+    def __len__(self):
+        return len(self.nodes)
+
+    def is_leaf(self):
+        return len(self.nodes) == 0
+
+    def add(self, name, val=0):
+        prefix, subname = split_name(name)
+        if subname == '':
+            self.nodes[name] = SegTree()
+            self.nodes[name].val = val
+            return
+        if self.nodes.get(prefix) is None:
+            self.nodes[prefix] = SegTree()
+        self.nodes[prefix].add(subname, val)
+
+    def change(self, name, val):
+        self.add(name, val)
+
+    def __getitem__(self, name: str):
+        if hasattr(self, name):
+            return getattr(self, name)
+        val = self.nodes.get(name)
+        if val is None:
+            # straight lookup failed, do a prefix lookup
+            keys = list(self.nodes.keys())
+            p_flag = [is_prefix(k, name) for k in keys]
+            if not any(p_flag):
+                return None
+            # either more than 1 match (error) or exactly 1 (success)
+            if np.sum(p_flag) > 1:
+                print(f"error: multiple matches of key {name} with {keys}")
+            else:
+                i = np.where(p_flag)[0][0]
+                n = numdots(keys[i])
+                prefix, substr = split_name(name, n)
+                return self.nodes[prefix][substr]
+        return val
+
+
+def model_to_tree(model):
+    keys = list(model.keys())
+    tree = SegTree()
+    for k in keys:
+        tree.add(k, "leaf")
+    return tree
+
+
+def get_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--input_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to the HuggingFace checkpoint folder (UNet or VAE)",
+    )
+    parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to the output checkpoint file.")
+    parser.add_argument("--precision", type=str, default="32", help="Model precision")
+    parser.add_argument("--model", type=str, default="unet", required=True, choices=['unet', 'vae'])
+    parser.add_argument("--debug", action='store_true', help="Useful for debugging purposes.")
+
+    args = parser.parse_args()
+    return args
+
+
+def make_tiny_config(config):
+    '''dial down the config file to make things tractable'''
+    # TODO
+    return config
+
+
+def load_hf_ckpt(in_dir, args):
+    ckpt = {}
+    with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f:
+        for k in f.keys():
+            ckpt[k] = f.get_tensor(k)
+    return args, ckpt
+
+
+def dup_convert_name_recursive(tree: SegTree, convert_name=None):
+    '''inside this tree, convert all nodes recursively
+    optionally, convert the name of the root as given by name (if not None)
+    '''
+    if tree is None:
+        return
+    if convert_name is not None:
+        tree.convert_name = convert_name
+    # recursively copy the name into convert_name
+    for k, v in tree.nodes.items():
+        dup_convert_name_recursive(v, k)
+
+
+def sanity_check(hf_tree, hf_unet, nemo_unet):
+    # check if I'm introducing new keys
+    for hfk, nk in hf_to_nemo_mapping(hf_tree).items():
+        if nk not in nemo_unet.keys():
+            print(nk)
+        if hfk not in hf_unet.keys():
+            print(hfk)
+
+
+def convert_input_keys(hf_tree: SegTree):
+    '''map the input blocks of the huggingface model'''
+    # map `conv_in` to first input block
+    dup_convert_name_recursive(hf_tree['conv_in'], 'input_blocks.0.0')
+
+    # start counting blocks from now on
+    nemo_inp_blk = 1
+    down_blocks = hf_tree['down_blocks']
+    down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=intkey)
+    for downblockid in down_blocks_keys:
+        block = down_blocks[str(downblockid)]
+        # compute number of resnets, attentions, downsamplers in this block
+        resnets = block.nodes.get('resnets', SegTree())
+        attentions = block.nodes.get('attentions', SegTree())
+        downsamplers = block.nodes.get('downsamplers', SegTree())
+
+        if len(attentions) == 0:  # no attentions, this is a DownBlock2D
+            for resid in sorted(list(resnets.nodes.keys()), key=intkey):
+                resid = str(resid)
+                resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0"
+                map_resnet_block(resnets[resid])
+                nemo_inp_blk += 1
+        elif len(attentions) == len(resnets):
+            # there are attention blocks here -- each resnet+attention becomes a block
+            for resid in sorted(list(resnets.nodes.keys()), key=intkey):
+                resid = str(resid)
+                resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0"
+                map_resnet_block(resnets[resid])
+                attentions[resid].convert_name = f"input_blocks.{nemo_inp_blk}.1"
+                map_attention_block(attentions[resid])
+                nemo_inp_blk += 1
+        else:
+            logging.warning("number of attention blocks is not the same as resnets - what's going on?")
+
+        # if there is a downsampler, then also append it
+        if len(downsamplers) > 0:
+            for k in downsamplers.nodes.keys():
+                downsamplers[k].convert_name = f"input_blocks.{nemo_inp_blk}.{k}"
+                dup_convert_name_recursive(downsamplers[k]['conv'], 'op')
+            nemo_inp_blk += 1
+
+
+def clean_convert_names(tree):
+    tree.convert_name = None
+    for k, v in tree.nodes.items():
+        clean_convert_names(v)
+
+
+def map_attention_block(att_tree: SegTree):
+    '''this HF tree can either be an AttentionBlock or a DualAttention block;
+    currently only AttentionBlock is assumed
+    '''
+
+    # TODO (rohit): Add check for dual attention block
+    def check_att_type(tree):
+        return "att_block"
+
+    if check_att_type(att_tree) == 'att_block':
+        dup_convert_name_recursive(att_tree['norm'], 'norm')
+        dup_convert_name_recursive(att_tree['proj_in'], 'proj_in')
+        dup_convert_name_recursive(att_tree['proj_out'], 'proj_out')
+        tblockids = list(att_tree['transformer_blocks'].nodes.keys())
+        for t in tblockids:
+            tblock = att_tree[f'transformer_blocks.{t}']
+            tblock.convert_name = f"transformer_blocks.{t}"
+            dup_convert_name_recursive(tblock['attn1'], 'attn1')
+            dup_convert_name_recursive(tblock['attn2'], 'attn2')
+            dup_convert_name_recursive(tblock['norm1'], 'attn1.norm')
+            dup_convert_name_recursive(tblock['norm2'], 'attn2.norm')
+            dup_convert_name_recursive(tblock['norm3'], 'ff.net.0')
+            # map ff module
+            tblock['ff'].convert_name = "ff"
+            tblock['ff.net'].convert_name = 'net'
+            dup_convert_name_recursive(tblock['ff.net.0'], '1')
+            dup_convert_name_recursive(tblock['ff.net.2'], '3')
+    else:
+        logging.warning("failed to identify type of attention block here.")
+
+
+def map_resnet_block(resnet_tree: SegTree):
+    '''this HF tree is supposed to have all the keys for a resnet'''
+    dup_convert_name_recursive(resnet_tree.nodes.get('time_emb_proj'), 'emb_layers.1')
+    dup_convert_name_recursive(resnet_tree['norm1'], 'in_layers.0')
+    dup_convert_name_recursive(resnet_tree['conv1'], 'in_layers.1')
+    dup_convert_name_recursive(resnet_tree['norm2'], 'out_layers.0')
+    dup_convert_name_recursive(resnet_tree['conv2'], 'out_layers.2')
+    dup_convert_name_recursive(resnet_tree.nodes.get('conv_shortcut'), 'skip_connection')
+
+
+def hf_to_nemo_mapping(tree: SegTree):
+    mapping = {}
+    for nodename, subtree in tree.nodes.items():
+        convert_name = subtree.convert_name
+        convert_name = (convert_name + ".") if convert_name is not None else ""
+        if subtree.is_leaf() and subtree.convert_name is not None:
+            mapping[nodename] = subtree.convert_name
+        else:
+            submapping = hf_to_nemo_mapping(subtree)
+            for k, v in submapping.items():
+                mapping[nodename + "." + k] = convert_name + v
+    return mapping
+
+
+def convert_cond_keys(tree: SegTree):
+    # map all conditioning keys
+    tree['add_embedding'].convert_name = 'label_emb.0'
+    dup_convert_name_recursive(tree['add_embedding.linear_1'], '0')
+    dup_convert_name_recursive(tree['add_embedding.linear_2'], '2')
+    tree['time_embedding'].convert_name = 'time_embed'
+    dup_convert_name_recursive(tree['time_embedding.linear_1'], '0')
+    dup_convert_name_recursive(tree['time_embedding.linear_2'], '2')
+
+
+def convert_middle_keys(tree: SegTree):
+    '''middle block is fixed (resnet -> attention -> resnet)'''
+    mid = tree['mid_block']
+    resnets = mid['resnets']
+    attns = mid['attentions']
+    mid.convert_name = 'middle_block'
+    resnets['0'].convert_name = '0'
+    resnets['1'].convert_name = '2'
+    attns['0'].convert_name = '1'
+    map_resnet_block(resnets['0'])
+    map_resnet_block(resnets['1'])
+    map_attention_block(attns['0'])
+
+
+def convert_output_keys(hf_tree: SegTree):
+    '''output keys are mapped similarly to input keys'''
+    nemo_inp_blk = 0
+    up_blocks = hf_tree['up_blocks']
+    up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=intkey)
+
+    for upblockid in up_blocks_keys:
+        block = up_blocks[str(upblockid)]
+        # compute number of resnets, attentions, upsamplers in this block
+        resnets = block.nodes.get('resnets', SegTree())
+        attentions = block.nodes.get('attentions', SegTree())
+        upsamplers = block.nodes.get('upsamplers', SegTree())
+
+        if len(attentions) == 0:  # no attentions, this is an UpBlock2D
+            for resid in sorted(list(resnets.nodes.keys()), key=intkey):
+                resid = str(resid)
+                resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0"
+                map_resnet_block(resnets[resid])
+                nemo_inp_blk += 1
+
+        elif len(attentions) == len(resnets):
+            # there are attention blocks here -- each resnet+attention becomes a block
+            for resid in sorted(list(resnets.nodes.keys()), key=intkey):
+                resid = str(resid)
+                resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0"
+                map_resnet_block(resnets[resid])
+                attentions[resid].convert_name = f"output_blocks.{nemo_inp_blk}.1"
+                map_attention_block(attentions[resid])
+                nemo_inp_blk += 1
+        else:
+            logging.warning("number of attention blocks is not the same as resnets - what's going on?")
+
+        # if there is an upsampler, then also append it
+        if len(upsamplers) > 0:
+            nemo_inp_blk -= 1
+            upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.2"
+            dup_convert_name_recursive(upsamplers['0.conv'], 'conv')
+            nemo_inp_blk += 1
+
+
+def convert_finalout_keys(hf_tree: SegTree):
+    dup_convert_name_recursive(hf_tree['conv_norm_out'], "out.0")
+    dup_convert_name_recursive(hf_tree['conv_out'], "out.1")
+
+
+def convert_encoder(hf_tree: SegTree):
+    encoder = hf_tree['encoder']
+    encoder.convert_name = 'encoder'
+    dup_convert_name_recursive(encoder['conv_in'], 'conv_in')
+    dup_convert_name_recursive(encoder['conv_out'], 'conv_out')
+    dup_convert_name_recursive(encoder['conv_norm_out'], 'norm_out')
+
+    # each block contains resnets and downsamplers
+    # there are also optional attention blocks in the down module, but I haven't encountered them yet
+    encoder['down_blocks'].convert_name = 'down'
+    for downid, downblock in encoder['down_blocks'].nodes.items():
+        downblock.convert_name = downid
+        downsamplers = downblock.nodes.get('downsamplers', SegTree())
+        dup_convert_name_recursive(downblock['resnets'], 'block')
+        # check for conv_shortcuts here
+        for resid, resnet in downblock['resnets'].nodes.items():
+            if resnet.nodes.get('conv_shortcut') is not None:
+                resnet.nodes['conv_shortcut'].convert_name = 'nin_shortcut'
+        if len(downsamplers) > 0:
+            dup_convert_name_recursive(downsamplers['0'], 'downsample')
+
+    # map the `mid_block` (NeMo's mid layer is hardcoded in terms of number of modules)
+    encoder['mid_block'].convert_name = 'mid'
+    dup_convert_name_recursive(encoder[f'mid_block.resnets.0'], 'block_1')
+    dup_convert_name_recursive(encoder[f'mid_block.resnets.1'], 'block_2')
+
+    # attention part
+    att = encoder['mid_block.attentions.0']
+    att.convert_name = 'attn_1'
+    dup_convert_name_recursive(att['group_norm'], 'norm')
+    dup_convert_name_recursive(att['to_k'], 'k')
+    dup_convert_name_recursive(att['to_q'], 'q')
+    dup_convert_name_recursive(att['to_v'], 'v')
+    dup_convert_name_recursive(att['to_out.0'], 'proj_out')
+
+
+def convert_decoder(hf_tree: SegTree):
+    decoder = hf_tree['decoder']
+    decoder.convert_name = 'decoder'
+    dup_convert_name_recursive(decoder['conv_in'], 'conv_in')
+    dup_convert_name_recursive(decoder['conv_out'], 'conv_out')
+    dup_convert_name_recursive(decoder['conv_norm_out'], 'norm_out')
+    # each block contains resnets and upsamplers
+    # map the `mid_block` (NeMo's mid layer is hardcoded in terms of number of modules)
+    decoder['mid_block'].convert_name = 'mid'
+    dup_convert_name_recursive(decoder[f'mid_block.resnets.0'], 'block_1')
+    dup_convert_name_recursive(decoder[f'mid_block.resnets.1'], 'block_2')
+    att = decoder['mid_block.attentions.0']
+    att.convert_name = 'attn_1'
+    dup_convert_name_recursive(att['group_norm'], 'norm')
+    dup_convert_name_recursive(att['to_k'], 'k')
+    dup_convert_name_recursive(att['to_q'], 'q')
+    dup_convert_name_recursive(att['to_v'], 'v')
+    dup_convert_name_recursive(att['to_out.0'], 'proj_out')
+
+    # up blocks contain resnets and upsamplers
+    decoder['up_blocks'].convert_name = 'up'
+    num_up_blocks = len(decoder['up_blocks'])
+    for upid, upblock in decoder['up_blocks'].nodes.items():
+        upblock.convert_name = str(num_up_blocks - 1 - int(upid))
+        upsamplers = upblock.nodes.get('upsamplers', SegTree())
+        dup_convert_name_recursive(upblock['resnets'], 'block')
+        # check for conv_shortcuts here
+        for resid, resnet in upblock['resnets'].nodes.items():
+            if resnet.nodes.get('conv_shortcut') is not None:
+                resnet.nodes['conv_shortcut'].convert_name = 'nin_shortcut'
+        if len(upsamplers) > 0:
+            dup_convert_name_recursive(upsamplers['0'], 'upsample')
+
+
+def convert(args):
+    logging.info(f"loading checkpoint {args.input_name_or_path}")
+    _, hf_ckpt = load_hf_ckpt(args.input_name_or_path, args)
+    hf_tree = model_to_tree(hf_ckpt)
+
+    if args.model == 'unet':
+        logging.info("converting unet...")
+        convert_input_keys(hf_tree)
+        convert_cond_keys(hf_tree)
+        convert_middle_keys(hf_tree)
+        convert_output_keys(hf_tree)
+        convert_finalout_keys(hf_tree)
+
+    elif args.model == 'vae':
+        logging.info("converting vae...")
+        dup_convert_name_recursive(hf_tree['quant_conv'], 'quant_conv')
+        dup_convert_name_recursive(hf_tree['post_quant_conv'], 'post_quant_conv')
+        convert_encoder(hf_tree)
+        convert_decoder(hf_tree)
+
+    else:
+        logging.error("incorrect model specification.")
+        return
+
+    # check mapping
+    mapping = hf_to_nemo_mapping(hf_tree)
+    if len(mapping) != len(hf_ckpt.keys()):
+        logging.warning("not all keys are matched properly.")
+    nemo_ckpt = {}
+
+    for hf_key, nemo_key in mapping.items():
+        nemo_ckpt[nemo_key] = hf_ckpt[hf_key]
+    torch.save(nemo_ckpt, args.output_path)
+    logging.info(f"Saved NeMo checkpoint to {args.output_path}")
+
+
+if __name__ == '__main__':
+    args = get_args()
+    convert(args)
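As a quick illustration of how the converter's key mapping fits together, here is a toy walkthrough using only the functions defined above (a sketch; paste it into a REPL after importing the script):

    import torch

    # tiny fake HF checkpoint: two leaf tensors under one prefix
    fake_ckpt = {"quant_conv.weight": torch.zeros(1), "quant_conv.bias": torch.zeros(1)}
    tree = model_to_tree(fake_ckpt)
    # identity-map the prefix, exactly as the 'vae' branch of convert() does
    dup_convert_name_recursive(tree['quant_conv'], 'quant_conv')
    print(hf_to_nemo_mapping(tree))
    # -> {'quant_conv.weight': 'quant_conv.weight', 'quant_conv.bias': 'quant_conv.bias'}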