diff --git a/helpers/arguments.py b/helpers/arguments.py
index 55a65cdc..e62097e7 100644
--- a/helpers/arguments.py
+++ b/helpers/arguments.py
@@ -142,13 +142,13 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--seen_state_path",
         type=str,
-        default=None,
+        default='seen_state.json',
         help="Where the JSON document containing the state of the seen images is stored. This helps ensure we do not repeat images too many times.",
     )
     parser.add_argument(
         "--state_path",
         type=str,
-        default=None,
+        default='training_state.json',
         help="A JSON document containing the current state of training, will be placed here.",
     )
     parser.add_argument(
diff --git a/sdxl-env.sh.example b/sdxl-env.sh.example
index 5625db0d..69f2bd7b 100644
--- a/sdxl-env.sh.example
+++ b/sdxl-env.sh.example
@@ -89,12 +89,15 @@ export TRAINER_EXTRA_ARGS="--allow_tf32 --use_8bit_adam --use_ema" # anything y
 export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --enable_xformers_memory_efficient_attention --use_original_images=true"
 export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --gradient_checkpointing --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS}"
 
-## For offset noise training.
+## For offset noise training:
 #export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --offset_noise --noise_offset=0.02"
 
-## For noise input pertubation - adds extra noise, randomly. This is separate from offset noise.
+## For noise input perturbation - adds extra noise, randomly. This is separate from offset noise:
 #export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --input_pertubation=0.01"
 
 ## For terminal SNR training:
 #export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --prediction_type=v_prediction --rescale_betas_zero_snr"
-#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --training_scheduler_timestep_spacing=leading --inference_scheduler_timestep_spacing=trailing"
\ No newline at end of file
+#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --training_scheduler_timestep_spacing=leading --inference_scheduler_timestep_spacing=trailing"
+
+## For experimental min-SNR weighted loss training (5 is the value suggested by the original researchers):
+#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --snr_gamma=5.0"
\ No newline at end of file
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48dbe8bf..66be0f40 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -567,7 +567,10 @@ def main(args):
                     ).min(dim=1)[0]
                     / snr
                 )
-
+                # An experimental strategy for making min-SNR compatible with zero terminal SNR is to set the loss
+                # weight to 1 for any timesteps whose SNR is zero. This preserves their loss contribution and helps
+                # prevent exploding gradients or NaNs caused by dividing by very small numbers.
+                mse_loss_weights[snr == 0] = 1.0
                 if torch.any(torch.isnan(mse_loss_weights)):
                     print("mse_loss_weights contains NaN values")
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
diff --git a/train_sdxl.py b/train_sdxl.py
index 94bdc7f0..4f42027d 100644
--- a/train_sdxl.py
+++ b/train_sdxl.py
@@ -34,6 +34,7 @@
 from helpers.vae_cache import VAECache
 from helpers.arguments import parse_args
 from helpers.custom_schedule import get_polynomial_decay_schedule_with_warmup
+from helpers.min_snr_gamma import compute_snr
 
 logger = logging.getLogger()
 filelock_logger = logging.getLogger("filelock")
@@ -515,7 +516,6 @@ def collate_fn(examples):
         [add_text_embeds_all for _ in range(1)], dim=0
     )
 
-    logger.debug(f"Returning collate_fn results.")
     return {
         "pixel_values": pixel_values,
         "prompt_embeds": prompt_embeds_all,
@@ -833,7 +833,12 @@ def collate_fn(examples):
                 # The chance of this happening is dictated by the caption_dropout_probability.
                 if random.random() < args.caption_dropout_probability:
                     training_logger.debug(f'Caption dropout triggered.')
-                    encoder_hidden_states = embed_cache.compute_embeddings_for_prompts([""])
+                    (
+                        batch["prompt_embeds_all"],
+                        batch["add_text_embeds_all"],
+                    ) = embed_cache.compute_embeddings_for_prompts([""])
+
+                # Conditioning dropout not yet supported.
                 add_text_embeds = batch["add_text_embeds"]
                 training_logger.debug(
                     f"Encoder hidden states: {encoder_hidden_states.shape}"
@@ -870,9 +875,49 @@ def collate_fn(examples):
                     added_cond_kwargs=added_cond_kwargs,
                 ).sample
 
-                training_logger.debug(f"Calculating loss")
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
+                if args.snr_gamma is None:
+                    training_logger.debug(f"Calculating loss")
+                    loss = F.mse_loss(
+                        model_pred.float(), target.float(), reduction="mean"
+                    )
+                else:
+                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+                    # This is discussed in Section 4.2 of the same paper.
+                    training_logger.debug(f"Using min-SNR loss")
+                    snr = compute_snr(timesteps, noise_scheduler)
+
+                    if torch.any(torch.isnan(snr)):
+                        training_logger.error("snr contains NaN values")
+                    if torch.any(snr == 0):
+                        training_logger.error("snr contains zero values")
+                    training_logger.debug(f'Calculating MSE loss weights using SNR as divisor')
+                    mse_loss_weights = (
+                        torch.stack(
+                            [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1
+                        ).min(dim=1)[0]
+                        / snr
+                    )
+                    # An experimental strategy for making min-SNR compatible with zero terminal SNR is to set the loss
+                    # weight to 1 for any timesteps whose SNR is zero. This preserves their loss contribution and helps
+                    # prevent exploding gradients or NaNs caused by dividing by very small numbers.
+                    mse_loss_weights[snr == 0] = 1.0
+                    if torch.any(torch.isnan(mse_loss_weights)):
+                        training_logger.error("mse_loss_weights contains NaN values")
+                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
+                    # rebalance the sample-wise losses with their respective loss weights.
+                    # Finally, we take the mean of the rebalanced loss.
+                    training_logger.debug(f'Calculating original MSE loss without reduction')
+                    loss = F.mse_loss(
+                        model_pred.float(), target.float(), reduction="none"
+                    )
+                    training_logger.debug(f'Calculating SNR-weighted MSE loss')
+                    loss = (
+                        loss.mean(dim=list(range(1, len(loss.shape))))
+                        * mse_loss_weights
+                    )
+                    training_logger.debug(f'Reducing loss via mean')
+                    loss = loss.mean()
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                 train_loss += avg_loss.item() / args.gradient_accumulation_steps
diff --git a/train_sdxl.sh b/train_sdxl.sh
index 660d622b..5d12ad86 100644
--- a/train_sdxl.sh
+++ b/train_sdxl.sh
@@ -8,7 +8,7 @@ accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}
   --pretrained_model_name_or_path="${MODEL_NAME}" \
   --resume_from_checkpoint="${RESUME_CHECKPOINT}" \
   --learning_rate="${LEARNING_RATE}" --lr_scheduler="${LR_SCHEDULE}" --seed "${TRAINING_SEED}" \
-  --instance_data_dir="${INSTANCE_DIR}" --seen_state_path="${SEEN_STATE_PATH}" \
+  --instance_data_dir="${INSTANCE_DIR}" --seen_state_path="${SEEN_STATE_PATH}" --state_path="${STATE_PATH}" \
   ${DEBUG_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --vae_dtype="${MIXED_PRECISION}" ${TRAINER_EXTRA_ARGS} \
   --train_batch="${TRAIN_BATCH_SIZE}" --caption_dropout_probability=${CAPTION_DROPOUT_PROBABILITY} \
   --validation_prompt="${VALIDATION_PROMPT}" --num_validation_images=1 \
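
Note: helpers/min_snr_gamma.py itself is not included in this diff, so the compute_snr() helper imported by train_sdxl.py is not shown. The following is a minimal sketch of what that module presumably contains, assuming the standard min-SNR formulation (SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t])) and matching the compute_snr(timesteps, noise_scheduler) call signature used above; it is not the actual file from the repository.

# helpers/min_snr_gamma.py -- sketch only; the real module is not part of this diff.
import torch


def compute_snr(timesteps, noise_scheduler):
    """Return the signal-to-noise ratio for each sampled timestep.

    Follows Section 3.4 of https://arxiv.org/abs/2303.09556:
    SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]).
    With --rescale_betas_zero_snr, the terminal timestep has alphas_cumprod == 0,
    so its SNR is 0, which is why the training loop special-cases snr == 0.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(device=timesteps.device)
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Gather the per-timestep alpha and sigma terms for the sampled timesteps.
    alpha = sqrt_alphas_cumprod[timesteps].float()
    sigma = sqrt_one_minus_alphas_cumprod[timesteps].float()

    # SNR = (alpha / sigma) ** 2
    return (alpha / sigma) ** 2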
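For intuition on the weighting that --snr_gamma=5.0 produces: min(SNR, gamma) / SNR leaves noisy timesteps (SNR <= gamma) at full weight and progressively down-weights low-noise timesteps. A quick standalone check, illustrative only and not part of the patch:

import torch

snr = torch.tensor([0.05, 1.0, 5.0, 25.0, 400.0])
gamma = 5.0
# min(SNR, gamma) / SNR, the same quantity as the mse_loss_weights expression above.
weights = torch.minimum(snr, torch.full_like(snr, gamma)) / snr
print(weights)  # tensor([1.0000, 1.0000, 1.0000, 0.2000, 0.0125])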