Skip to content

Commit

Permalink
Merge pull request #33 from bghira/main
Browse files Browse the repository at this point in the history
Logging luminance for training data; Input perturbation for SDXL; Caption dropout fix for SDXL; Save 'seen' state for SDXL aspect bucketing
  • Loading branch information
bghira authored Aug 13, 2023
2 parents d103496 + 77f5abc commit f27f8e1
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 15 deletions.
8 changes: 7 additions & 1 deletion helpers/image_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,10 @@ def calculate_luminance(img: Image):

# Return average luminance for the entire image
avg_luminance = sum(luminance_values) / len(luminance_values)
return avg_luminance
return avg_luminance

def calculate_batch_luminance(imgs: list):
luminance_values = []
for img in imgs:
luminance_values.append(calculate_luminance(img))
return sum(luminance_values) / len(luminance_values)
10 changes: 6 additions & 4 deletions sdxl-env.sh.example
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ export LR_WARMUP_STEPS=$((MAX_NUM_STEPS / 10))
# Adjust this for your GPU memory size.
export TRAIN_BATCH_SIZE=10

# Currently not implemented.
NOISE_OFFSET=0.0

# Validation image settings.
VALIDATION_GUIDANCE=7.5
VALIDATION_GUIDANCE_RESCALE=0.0
Expand Down Expand Up @@ -92,7 +89,12 @@ export TRAINER_EXTRA_ARGS="--allow_tf32 --use_8bit_adam --use_ema" # anything y
export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --enable_xformers_memory_efficient_attention --use_original_images=true"
export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --gradient_checkpointing --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS}"

## For terminal SNR training:
## For offset noise training.
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --offset_noise --noise_offset=0.02"

## For noise input pertubation - adds extra noise, randomly. This is separate from offset noise.
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --input_pertubation=0.01"

## For terminal SNR training:
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --prediction_type=v_prediction --rescale_betas_zero_snr"
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --training_scheduler_timestep_spacing=leading --inference_scheduler_timestep_spacing=trailing"
6 changes: 4 additions & 2 deletions train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,14 +493,16 @@ def main(args):
latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents - args.noise_offset might need to be set to 0.1 by default.
noise = None
if args.offset_noise:
noise = torch.randn_like(latents) + args.noise_offset * torch.randn(
latents.shape[0], latents.shape[1], 1, 1, device=latents.device
)
else:
noise = torch.randn_like(latents)
if args.input_pertubation:
new_noise = noise + args.input_pertubation * torch.randn_like(noise)

else:
elif noise is None:
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
Expand Down
44 changes: 37 additions & 7 deletions train_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from helpers.dreambooth_dataset import DreamBoothDataset
from helpers.state_tracker import StateTracker
from helpers.sdxl_embeds import TextEmbeddingCache
from helpers.image_tools import calculate_luminance
from helpers.image_tools import calculate_luminance, calculate_batch_luminance
from helpers.vae_cache import VAECache
from helpers.arguments import parse_args
from helpers.custom_schedule import get_polynomial_decay_schedule_with_warmup
Expand Down Expand Up @@ -474,7 +474,10 @@ def collate_fn(examples):
logger.debug(f"Not training, returning nothing from collate_fn")
return
training_logger.debug(f"Examples: {examples}")

training_logger.debug(f"Computing luminance for input batch")
batch_luminance = calculate_batch_luminance(
[example["instance_images"] for example in examples]
)
# Initialize the VAE Cache if it doesn't exist
global vaecache
if "vaecache" not in globals():
Expand Down Expand Up @@ -518,6 +521,7 @@ def collate_fn(examples):
"prompt_embeds": prompt_embeds_all,
"add_text_embeds": add_text_embeds_all,
"add_time_ids": compute_time_ids(width, height),
"luminance": batch_luminance,
}

# DataLoaders creation:
Expand Down Expand Up @@ -752,6 +756,7 @@ def collate_fn(examples):
logger.debug(f"Starting into epoch: {epoch}")
unet.train()
train_loss = 0.0
training_luminance_values = []
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if (
Expand All @@ -774,14 +779,30 @@ def collate_fn(examples):
if batch is None:
logging.warning(f"Burning a None size batch.")
continue

# Add the current batch of training data's avg luminance to a list.
training_luminance_values.append(batch["luminance"])

with accelerator.accumulate(unet):
training_logger.debug(f"Beginning another step.")
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
training_logger.debug("Moved pixels to accelerator.")
latents = pixel_values
# Sample noise that we'll add to the latents
training_logger.debug(f"Sampling random noise")
noise = torch.randn_like(latents)
# Sample noise that we'll add to the latents - args.noise_offset might need to be set to 0.1 by default.
noise = None
if args.offset_noise:
noise = torch.randn_like(latents) + args.noise_offset * torch.randn(
latents.shape[0], latents.shape[1], 1, 1, device=latents.device
)
else:
noise = torch.randn_like(latents)
if args.input_pertubation:
new_noise = noise + args.input_pertubation * torch.randn_like(noise)
elif noise is None:
noise = torch.randn_like(latents)

bsz = latents.shape[0]
training_logger.debug(f"Working on batch size: {bsz}")
# Sample a random timestep for each image
Expand All @@ -796,19 +817,23 @@ def collate_fn(examples):

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if args.input_pertubation:
noisy_latents = noise_scheduler.add_noise(
latents, new_noise, timesteps
)
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
training_logger.debug(
f"Generated noisy latent frame from latents and noise."
)

# SDXL additional inputs - probabilistic dropout
encoder_hidden_states = batch["prompt_embeds"]
if args.caption_dropout_probability is not None and args.caption_dropout_probability > 0:
# When using caption dropout, we will use the null embed instead of prompt embeds.
# The chance of this happening is dictated by the caption_dropout_probability.
if random.random() < args.caption_dropout_probability:
training_logger.debug(f'Caption dropout triggered.')
encoder_hidden_states = embed_cache.get_embeddings_for_prompts([""])
encoder_hidden_states = embed_cache.compute_embeddings_for_prompts([""])
add_text_embeds = batch["add_text_embeds"]
training_logger.debug(
f"Encoder hidden states: {encoder_hidden_states.shape}"
Expand Down Expand Up @@ -870,7 +895,11 @@ def collate_fn(examples):
ema_unet.step(unet.parameters())
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss, "learning_rate": lr_scheduler.get_last_lr()[0]}, step=global_step)
# Average out the luminance values of each batch, so that we can store that in this step.
avg_training_data_luminance = sum(training_luminance_values) / len(training_luminance_values)
accelerator.log({"train_luminance": avg_training_data_luminance, "train_loss": train_loss, "learning_rate": lr_scheduler.get_last_lr()[0]}, step=global_step)
# Reset some values for the next go.
training_luminance_values = []
train_loss = 0.0

if global_step % args.checkpointing_steps == 0:
Expand Down Expand Up @@ -909,6 +938,7 @@ def collate_fn(examples):
args.output_dir, f"checkpoint-{global_step}"
)
accelerator.save_state(save_path)
custom_balanced_sampler.save_state()
logger.info(f"Saved state to {save_path}")

logs = {
Expand Down
2 changes: 1 addition & 1 deletion train_sdxl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}
--resolution="${RESOLUTION}" --validation_resolution="${RESOLUTION}" \
--checkpointing_steps="${CHECKPOINTING_STEPS}" --checkpoints_total_limit="${CHECKPOINTING_LIMIT}" \
--validation_steps="${VALIDATION_STEPS}" --tracker_run_name="${TRACKER_RUN_NAME}" --num_train_epochs="${NUM_EPOCHS}" \
--noise_offset="${NOISE_OFFSET}" --validation_guidance="${VALIDATION_GUIDANCE}" --validation_guidance_rescale="${VALIDATION_GUIDANCE_RESCALE}"
--validation_guidance="${VALIDATION_GUIDANCE}" --validation_guidance_rescale="${VALIDATION_GUIDANCE_RESCALE}"

exit 0

0 comments on commit f27f8e1

Please sign in to comment.