Commit

Merge pull request #34 from bghira/main
Min-SNR: using the minimum SNR as an MSE weighted loss
bghira authored Aug 14, 2023
2 parents f27f8e1 + 9f6bbd4 commit d527b8f
Showing 5 changed files with 63 additions and 12 deletions.
4 changes: 2 additions & 2 deletions helpers/arguments.py
@@ -142,13 +142,13 @@ def parse_args(input_args=None):
parser.add_argument(
"--seen_state_path",
type=str,
default=None,
default='seen_state.json',
help="Where the JSON document containing the state of the seen images is stored. This helps ensure we do not repeat images too many times.",
)
parser.add_argument(
"--state_path",
type=str,
default=None,
default='training_state.json',
help="A JSON document containing the current state of training, will be placed here.",
)
parser.add_argument(
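These two options now default to concrete filenames instead of None, so state tracking is always enabled. As a purely hypothetical illustration (the helper that actually reads and writes these documents is not part of this diff, and these function names are invented), a seen-state file such as seen_state.json could be maintained along these lines:

import json
import os

def load_seen_state(path="seen_state.json"):
    # Hypothetical sketch: return the set of image identifiers already shown to the model.
    if os.path.exists(path):
        with open(path, "r") as f:
            return set(json.load(f))
    return set()

def save_seen_state(seen_images, path="seen_state.json"):
    # Hypothetical sketch: persist the seen set so a resumed run does not repeat images too often.
    with open(path, "w") as f:
        json.dump(sorted(seen_images), f)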
9 changes: 6 additions & 3 deletions sdxl-env.sh.example
@@ -89,12 +89,15 @@ export TRAINER_EXTRA_ARGS="--allow_tf32 --use_8bit_adam --use_ema" # anything y
export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --enable_xformers_memory_efficient_attention --use_original_images=true"
export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --gradient_checkpointing --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS}"

## For offset noise training.
## For offset noise training:
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --offset_noise --noise_offset=0.02"

## For noise input perturbation - adds extra noise, randomly. This is separate from offset noise.
## For noise input perturbation - adds extra noise, randomly. This is separate from offset noise:
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --input_pertubation=0.01"

## For terminal SNR training:
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --prediction_type=v_prediction --rescale_betas_zero_snr"
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --training_scheduler_timestep_spacing=leading --inference_scheduler_timestep_spacing=trailing"

## For experimental min-SNR weighted loss training (5 is suggested value by the original researchers):
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --snr_gamma=5.0"
5 changes: 4 additions & 1 deletion train_dreambooth.py
@@ -567,7 +567,10 @@ def main(args):
).min(dim=1)[0]
/ snr
)

# An experimental strategy for combining min-SNR weighting with zero terminal SNR is to set the loss weight to 1
# for any timesteps whose SNR is zero. This preserves those samples' loss values and hopefully prevents
# exploding gradients or NaNs arising from division by very small numbers.
mse_loss_weights[snr == 0] = 1.0
if torch.any(torch.isnan(mse_loss_weights)):
print("mse_loss_weights contains NaN values")
# We first calculate the original loss. Then we mean over the non-batch dimensions and
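A toy example (values are made up) of why the snr == 0 guard above matters: with a rescaled zero-terminal-SNR schedule, the final timestep has an SNR of exactly 0, so min(SNR, gamma) / SNR becomes 0 / 0, i.e. NaN, which would then poison the batch loss.

import torch

snr = torch.tensor([12.7, 4.2, 0.9, 0.0])  # hypothetical per-sample SNR values
gamma = 5.0
weights = torch.stack([snr, gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
# weights -> tensor([0.3937, 1.0000, 1.0000, nan]); the 0/0 entry is NaN
weights[snr == 0] = 1.0  # the guard added in this commit
# weights -> tensor([0.3937, 1.0000, 1.0000, 1.0000])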
55 changes: 50 additions & 5 deletions train_sdxl.py
Expand Up @@ -34,6 +34,7 @@
from helpers.vae_cache import VAECache
from helpers.arguments import parse_args
from helpers.custom_schedule import get_polynomial_decay_schedule_with_warmup
from helpers.min_snr_gamma import compute_snr

logger = logging.getLogger()
filelock_logger = logging.getLogger("filelock")
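The compute_snr helper imported here is not among this commit's changed files, so its body does not appear in the diff. As a rough sketch only (assuming the usual definition SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for a DDPM-style scheduler; the real helpers/min_snr_gamma.py may differ in detail), it would look something like:

import torch

def compute_snr(timesteps, noise_scheduler):
    # Sketch only: gather alpha_bar for each sampled timestep and return its SNR.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
    alpha_bar = alphas_cumprod[timesteps].float()
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t); this is 0 at the final timestep
    # when the schedule has been rescaled for zero terminal SNR.
    return alpha_bar / (1.0 - alpha_bar)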
@@ -515,7 +516,6 @@ def collate_fn(examples):
[add_text_embeds_all for _ in range(1)], dim=0
)

logger.debug(f"Returning collate_fn results.")
return {
"pixel_values": pixel_values,
"prompt_embeds": prompt_embeds_all,
@@ -833,7 +833,12 @@ def collate_fn(examples):
# The chance of this happening is dictated by the caption_dropout_probability.
if random.random() < args.caption_dropout_probability:
training_logger.debug(f'Caption dropout triggered.')
encoder_hidden_states = embed_cache.compute_embeddings_for_prompts([""])
(
batch["prompt_embeds_all"],
batch["add_text_embeds_all"],
) = embed_cache.compute_embeddings_for_prompts([""])

# Conditioning dropout not yet supported.
add_text_embeds = batch["add_text_embeds"]
training_logger.debug(
f"Encoder hidden states: {encoder_hidden_states.shape}"
@@ -870,9 +875,49 @@ def collate_fn(examples):
added_cond_kwargs=added_cond_kwargs,
).sample

training_logger.debug(f"Calculating loss")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

if args.snr_gamma is None:
training_logger.debug(f"Calculating loss")
loss = F.mse_loss(
model_pred.float(), target.float(), reduction="mean"
)
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
training_logger.debug(f"Using min-SNR loss")
snr = compute_snr(timesteps, noise_scheduler)

if torch.any(torch.isnan(snr)):
training_logger.error("snr contains NaN values")
if torch.any(snr == 0):
training_logger.error("snr contains zero values")
training_logger.debug(f'Calculating MSE loss weights using SNR as divisor')
mse_loss_weights = (
torch.stack(
[snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1
).min(dim=1)[0]
/ snr
)
# An experimental strategy for combining min-SNR weighting with zero terminal SNR is to set the loss weight to 1
# for any timesteps whose SNR is zero. This preserves those samples' loss values and hopefully prevents
# exploding gradients or NaNs arising from division by very small numbers.
mse_loss_weights[snr == 0] = 1.0
if torch.any(torch.isnan(mse_loss_weights)):
training_logger.error("mse_loss_weights contains NaN values")
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
training_logger.debug(f'Calculating original MSE loss without reduction')
loss = F.mse_loss(
model_pred.float(), target.float(), reduction="none"
)
training_logger.debug(f'Calculating SNR-weighted MSE loss')
loss = (
loss.mean(dim=list(range(1, len(loss.shape))))
* mse_loss_weights
)
training_logger.debug(f'Reducing loss via mean')
loss = loss.mean()
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
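Condensed for readability, the weighting logic this hunk adds amounts to the following standalone function (a restatement of the diff above, not additional code from the commit); per the Min-SNR paper it references (arXiv:2303.09556), 5.0 is the suggested gamma:

import torch
import torch.nn.functional as F

def min_snr_weighted_mse(model_pred, target, snr, snr_gamma=5.0):
    # w_t = min(SNR(t), gamma) / SNR(t), with w_t = 1 wherever SNR(t) == 0
    weights = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
    weights[snr == 0] = 1.0  # zero-terminal-SNR workaround from this commit
    # Per-sample MSE, rebalanced by the SNR-derived weights, then averaged.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * weights
    return loss.mean()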
2 changes: 1 addition & 1 deletion train_sdxl.sh
@@ -8,7 +8,7 @@ accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}
--pretrained_model_name_or_path="${MODEL_NAME}" \
--resume_from_checkpoint="${RESUME_CHECKPOINT}" \
--learning_rate="${LEARNING_RATE}" --lr_scheduler="${LR_SCHEDULE}" --seed "${TRAINING_SEED}" \
--instance_data_dir="${INSTANCE_DIR}" --seen_state_path="${SEEN_STATE_PATH}" \
--instance_data_dir="${INSTANCE_DIR}" --seen_state_path="${SEEN_STATE_PATH}" --state_path="${STATE_PATH}" \
${DEBUG_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --vae_dtype="${MIXED_PRECISION}" ${TRAINER_EXTRA_ARGS} \
--train_batch="${TRAIN_BATCH_SIZE}" --caption_dropout_probability=${CAPTION_DROPOUT_PROBABILITY} \
--validation_prompt="${VALIDATION_PROMPT}" --num_validation_images=1 \
