Skip to content

Commit

Permalink
Merge pull request #26 from bghira/main
Browse files Browse the repository at this point in the history
Validation image settings now exist
  • Loading branch information
bghira authored Aug 10, 2023
2 parents ae1c5cb + fae3405 commit c550df5
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 13 deletions.
12 changes: 12 additions & 0 deletions helpers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,18 @@ def parse_args(input_args=None):
default=5,
help="Run validation every X epochs.",
)
parser.add_argument(
'--validation_guidance',
type=float,
default=7.5,
help="CFG value for validation images. Default: 7.5",
)
parser.add_argument(
'--validation_guidance_rescale',
type=float,
default=0.0,
help="CFG rescale value for validation images. Default: 0.0, max 1.0",
)
parser.add_argument(
"--freeze_encoder_before",
type=int,
Expand Down
12 changes: 6 additions & 6 deletions helpers/aspect_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ def remove_image(self, image_path, bucket):

def handle_small_image(self, image_path, bucket):
logger.warning(f"Image too small: DELETING image and continuing search.")
try:
os.remove(image_path)
except Exception as e:
logger.warning(
f"The image was already deleted. Another GPU must have gotten to it."
)
# try:
# os.remove(image_path)
# except Exception as e:
# logger.warning(
# f"The image was already deleted. Another GPU must have gotten to it."
# )
self.remove_image(image_path, bucket)

def handle_incorrect_bucket(self, image_path, bucket, actual_bucket):
Expand Down
8 changes: 8 additions & 0 deletions sdxl-env.sh.example
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ export LR_WARMUP_STEPS=$((MAX_NUM_STEPS / 10))
# Adjust this for your GPU memory size.
export TRAIN_BATCH_SIZE=10

# Currently not implemented.
NOISE_OFFSET=0.0

# Validation image settings.
VALIDATION_GUIDANCE=7.5
VALIDATION_GUIDANCE_RESCALE=0.0


# Leave these alone unless you know what you are doing.
export RESOLUTION=1024
export GRADIENT_ACCUMULATION_STEPS=4 # Yes, it slows training down. No, you don't want to change this.
Expand Down
12 changes: 7 additions & 5 deletions train_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def collate_fn(examples):
ema_unet.step(unet.parameters())
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
accelerator.log({"train_loss": train_loss, "learning_rate": lr_scheduler.get_last_lr()[0]}, step=global_step)
train_loss = 0.0

if global_step % args.checkpointing_steps == 0:
Expand Down Expand Up @@ -945,8 +945,9 @@ def collate_fn(examples):
negative_prompt_embeds=validation_negative_prompt_embeds,
negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
num_images_per_prompt=args.num_validation_images,
num_inference_steps=20,
guidance_scale=7,
num_inference_steps=30,
guidance_scale=args.validation_guidance,
guidance_rescale=args.validation_guidance_rescale,
generator=validation_generator,
height=args.validation_resolution,
width=args.validation_resolution,
Expand Down Expand Up @@ -1020,8 +1021,9 @@ def collate_fn(examples):
negative_prompt_embeds=validation_negative_prompt_embeds,
negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
num_images_per_prompt=args.num_validation_images,
num_inference_steps=20,
guidance_scale=7,
num_inference_steps=30,
guidance_scale=args.validation_guidance,
guidance_rescale=args.validation_guidance_rescale,
generator=validation_generator,
height=args.validation_resolution,
width=args.validation_resolution,
Expand Down
5 changes: 3 additions & 2 deletions train_sdxl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ source sdxl-env.sh.example
accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --num_processes="${TRAINING_NUM_PROCESSES}" --num_machines="${TRAINING_NUM_MACHINES}" --dynamo_backend="${TRAINING_DYNAMO_BACKEND}" train_sdxl.py \
--pretrained_model_name_or_path="${MODEL_NAME}" \
--resume_from_checkpoint="${RESUME_CHECKPOINT}" \
--learning_rate="${LEARNING_RATE}" --seed "${TRAINING_SEED}" \
--learning_rate="${LEARNING_RATE}" --lr_scheduler="${LR_SCHEDULE}" --seed "${TRAINING_SEED}" \
--instance_data_dir="${INSTANCE_DIR}" --seen_state_path="${SEEN_STATE_PATH}" \
${DEBUG_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --vae_dtype="${MIXED_PRECISION}" ${TRAINER_EXTRA_ARGS} \
--train_batch="${TRAIN_BATCH_SIZE}" \
--validation_prompt="${VALIDATION_PROMPT}" --num_validation_images=1 \
--resolution="${RESOLUTION}" --validation_resolution="${RESOLUTION}" \
--checkpointing_steps="${CHECKPOINTING_STEPS}" --checkpoints_total_limit="${CHECKPOINTING_LIMIT}" \
--validation_steps="${VALIDATION_STEPS}" --tracker_run_name="${TRACKER_RUN_NAME}" --num_train_epochs="${NUM_EPOCHS}"
--validation_steps="${VALIDATION_STEPS}" --tracker_run_name="${TRACKER_RUN_NAME}" --num_train_epochs="${NUM_EPOCHS}" \
--noise_offset="${NOISE_OFFSET}" --validation_guidance="${VALIDATION_GUIDANCE}" --validation_guidance_rescale="${VALIDATION_GUIDANCE_RESCALE}"

exit 0

0 comments on commit c550df5

Please sign in to comment.