Min-SNR: SDXL support (experimental)
bghira committed Aug 13, 2023
1 parent b014eac commit 9f6bbd4
Showing 3 changed files with 54 additions and 7 deletions.
9 changes: 6 additions & 3 deletions sdxl-env.sh.example
@@ -89,12 +89,15 @@ export TRAINER_EXTRA_ARGS="--allow_tf32 --use_8bit_adam --use_ema" # anything y
export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --enable_xformers_memory_efficient_attention --use_original_images=true"
export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --gradient_checkpointing --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS}"

## For offset noise training.
## For offset noise training:
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --offset_noise --noise_offset=0.02"

## For noise input perturbation - adds extra noise, randomly. This is separate from offset noise.
## For noise input perturbation - adds extra noise, randomly. This is separate from offset noise:
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --input_pertubation=0.01"

## For terminal SNR training:
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --prediction_type=v_prediction --rescale_betas_zero_snr"
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --training_scheduler_timestep_spacing=leading --inference_scheduler_timestep_spacing=trailing"

## For experimental min-SNR weighted loss training (5 is the value suggested by the original researchers):
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --snr_gamma=5.0"
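For context, the weighting enabled by --snr_gamma follows the min-SNR formulation min(SNR(t), gamma) / SNR(t). The snippet below mirrors the loss-weight computation added to train_sdxl.py later in this commit; the tensor values are illustrative only.

import torch

# Per-timestep SNR values (illustrative); gamma = 5.0 as suggested above.
snr = torch.tensor([0.05, 1.0, 5.0, 20.0])
snr_gamma = 5.0

# min(SNR, gamma) / SNR: noisy (low-SNR) timesteps keep a weight of 1.0,
# while clean (high-SNR) timesteps are down-weighted to gamma / SNR.
mse_loss_weights = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
print(mse_loss_weights)  # tensor([1.0000, 1.0000, 1.0000, 0.2500])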
5 changes: 4 additions & 1 deletion train_dreambooth.py
@@ -567,7 +567,10 @@ def main(args):
    ).min(dim=1)[0]
    / snr
)

# An experimental strategy for fixing min-SNR with zero terminal SNR is to set loss weighting to 1 when
# any positional tensors have an SNR of zero. This is to preserve their loss values and also to hopefully
# prevent the explosion of gradients or NaNs due to the presence of very small numbers.
mse_loss_weights[snr == 0] = 1.0
if torch.any(torch.isnan(mse_loss_weights)):
    print("mse_loss_weights contains NaN values")
# We first calculate the original loss. Then we mean over the non-batch dimensions and
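The reason for the snr == 0 mask above: with --rescale_betas_zero_snr the final timestep has an SNR of exactly zero, so min(0, gamma) / 0 evaluates to 0 / 0 = NaN and would poison the reduced loss. A small illustration with made-up values:

import torch

snr = torch.tensor([0.0, 0.8, 12.0])  # includes a zero-terminal-SNR timestep
snr_gamma = 5.0

mse_loss_weights = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
print(mse_loss_weights)           # tensor([nan, 1.0000, 0.4167])
mse_loss_weights[snr == 0] = 1.0  # keep the zero-SNR sample's loss intact
print(mse_loss_weights)           # tensor([1.0000, 1.0000, 0.4167])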
47 changes: 44 additions & 3 deletions train_sdxl.py
@@ -34,6 +34,7 @@
from helpers.vae_cache import VAECache
from helpers.arguments import parse_args
from helpers.custom_schedule import get_polynomial_decay_schedule_with_warmup
from helpers.min_snr_gamma import compute_snr

logger = logging.getLogger()
filelock_logger = logging.getLogger("filelock")
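The compute_snr helper imported above lives in helpers/min_snr_gamma.py, which is not part of this diff. A minimal sketch of what such a helper typically computes is shown below, assuming the standard definition SNR(t) = alpha_bar_t / (1 - alpha_bar_t) from the min-SNR paper and a scheduler that exposes alphas_cumprod (as diffusers schedulers do); the actual implementation may differ.

import torch

def compute_snr_sketch(timesteps, noise_scheduler):
    # Hypothetical stand-in for helpers.min_snr_gamma.compute_snr, matching the
    # call signature used below: compute_snr(timesteps, noise_scheduler).
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
    alpha_bar = alphas_cumprod[timesteps].float()
    # Signal power over noise power for each sampled timestep.
    return alpha_bar / (1.0 - alpha_bar)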
@@ -874,9 +875,49 @@ def collate_fn(examples):
    added_cond_kwargs=added_cond_kwargs,
).sample

training_logger.debug(f"Calculating loss")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

if args.snr_gamma is None:
    training_logger.debug(f"Calculating loss")
    loss = F.mse_loss(
        model_pred.float(), target.float(), reduction="mean"
    )
else:
    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
    # This is discussed in Section 4.2 of the same paper.
    training_logger.debug(f"Using min-SNR loss")
    snr = compute_snr(timesteps, noise_scheduler)

    if torch.any(torch.isnan(snr)):
        training_logger.error("snr contains NaN values")
    if torch.any(snr == 0):
        training_logger.error("snr contains zero values")
    training_logger.debug(f'Calculating MSE loss weights using SNR as divisor')
    mse_loss_weights = (
        torch.stack(
            [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1
        ).min(dim=1)[0]
        / snr
    )
    # An experimental strategy for fixing min-SNR with zero terminal SNR is to set loss weighting to 1 when
    # any positional tensors have an SNR of zero. This is to preserve their loss values and also to hopefully
    # prevent the explosion of gradients or NaNs due to the presence of very small numbers.
    mse_loss_weights[snr == 0] = 1.0
    if torch.any(torch.isnan(mse_loss_weights)):
        training_logger.error("mse_loss_weights contains NaN values")
    # We first calculate the original loss. Then we mean over the non-batch dimensions and
    # rebalance the sample-wise losses with their respective loss weights.
    # Finally, we take the mean of the rebalanced loss.
    training_logger.debug(f'Calculating original MSE loss without reduction')
    loss = F.mse_loss(
        model_pred.float(), target.float(), reduction="none"
    )
    training_logger.debug(f'Calculating SNR-weighted MSE loss')
    loss = (
        loss.mean(dim=list(range(1, len(loss.shape))))
        * mse_loss_weights
    )
    training_logger.debug(f'Reducing loss via mean')
    loss = loss.mean()
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
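To make the reduction path above concrete: the unreduced MSE is averaged over the non-batch dimensions to give one loss per sample, each sample is rebalanced by its min-SNR weight, and the result is reduced to a scalar. A self-contained illustration with arbitrary shapes and weights:

import torch
import torch.nn.functional as F

batch, channels, height, width = 4, 4, 8, 8
model_pred = torch.randn(batch, channels, height, width)
target = torch.randn(batch, channels, height, width)
mse_loss_weights = torch.tensor([1.0, 1.0, 0.25, 1.0])  # one weight per sample

# Per-element squared error, then a per-sample mean over the C/H/W dimensions.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape))))  # shape: [batch]
# Rebalance each sample by its weight, then reduce to the final scalar loss.
loss = (loss * mse_loss_weights).mean()
print(loss.shape, loss.item())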
