
Commit 71bea97

Merge pull request #1083 from bghira/main
merge
bghira authored Oct 22, 2024
2 parents 056d1b9 + aec04fc commit 71bea97
Showing 11 changed files with 384 additions and 516 deletions.
8 changes: 6 additions & 2 deletions helpers/data_backend/factory.py
@@ -561,8 +561,11 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
raise ValueError("Each dataset needs a unique 'id' field.")
info_log(f"Configuring data backend: {backend['id']}")
conditioning_type = backend.get("conditioning_type")
if backend.get("dataset_type") == "conditioning" or conditioning_type is not None:
backend['dataset_type'] = 'conditioning'
if (
backend.get("dataset_type") == "conditioning"
or conditioning_type is not None
):
backend["dataset_type"] = "conditioning"
resolution_type = backend.get("resolution_type", args.resolution_type)
if resolution_type == "pixel_area":
pixel_edge_length = backend.get("resolution", int(args.resolution))
@@ -897,6 +900,7 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
),
instance_prompt=backend.get("instance_prompt", args.instance_prompt),
conditioning_type=conditioning_type,
+ is_regularisation_data=is_regularisation_data,
)
if init_backend["sampler"].caption_strategy == "parquet":
configure_parquet_database(backend, args, init_backend["data_backend"])
14 changes: 10 additions & 4 deletions helpers/multiaspect/sampler.py
@@ -39,6 +39,7 @@ def __init__(
prepend_instance_prompt=False,
instance_prompt: str = None,
conditioning_type: str = None,
+ is_regularisation_data: bool = False,
):
"""
Initializes the sampler with provided settings.
@@ -67,6 +68,7 @@ def __init__(
f"Unknown conditioning image type: {conditioning_type}"
)
self.conditioning_type = conditioning_type
+ self.is_regularisation_data = is_regularisation_data

self.rank_info = rank_info()
self.accelerator = accelerator
@@ -376,15 +378,19 @@ def log_state(self, show_rank: bool = True, alt_stats: bool = False):
# We don't know the direct count without more work, so we'll estimate it here for multi-GPU training.
total_image_count *= self.accelerator.num_processes
total_image_count = f"~{total_image_count}"
+ data_backend_config = StateTracker.get_data_backend_config(self.id)
printed_state = (
f"- Repeats: {StateTracker.get_data_backend_config(self.id).get('repeats', 0)}\n"
f"- Repeats: {data_backend_config.get('repeats', 0)}\n"
f"- Total number of images: {total_image_count}\n"
f"- Total number of aspect buckets: {len(self.buckets)}\n"
f"- Resolution: {self.resolution} {'megapixels' if self.resolution_type == 'area' else 'px'}\n"
f"- Cropped: {StateTracker.get_data_backend_config(self.id).get('crop')}\n"
f"- Crop style: {'None' if not StateTracker.get_data_backend_config(self.id).get('crop') else StateTracker.get_data_backend_config(self.id).get('crop_style')}\n"
f"- Crop aspect: {'None' if not StateTracker.get_data_backend_config(self.id).get('crop') else StateTracker.get_data_backend_config(self.id).get('crop_aspect')}\n"
f"- Cropped: {data_backend_config.get('crop')}\n"
f"- Crop style: {'None' if not data_backend_config.get('crop') else data_backend_config.get('crop_style')}\n"
f"- Crop aspect: {'None' if not data_backend_config.get('crop') else data_backend_config.get('crop_aspect')}\n"
f"- Used for regularisation data: {'Yes' if self.is_regularisation_data else 'No'}\n"
)
+ if self.conditioning_type:
+ printed_state += f"- Conditioning type: {self.conditioning_type}\n"
else:
# Return a snapshot of the current state during training.
printed_state = (
4 changes: 4 additions & 0 deletions helpers/publishing/metadata.py
@@ -257,6 +257,8 @@ def flux_schedule_info(args):
else " (no special parameters set)"
)

+ return output_str


def save_model_card(
repo_id: str,
@@ -281,6 +283,8 @@ def save_model_card(
logger.debug(f"Validating from prompts: {validation_prompts}")
assets_folder = os.path.join(repo_folder, "assets")
optimizer_config = StateTracker.get_args().optimizer_config
+ if optimizer_config is None:
+ optimizer_config = ""
os.makedirs(assets_folder, exist_ok=True)
datasets_str = ""
for dataset in StateTracker.get_data_backends().keys():
29 changes: 7 additions & 22 deletions helpers/training/trainer.py
@@ -1395,13 +1395,9 @@ def init_benchmark_base_model(self):
structured_data={"message": "Base model benchmark begins"},
message_type="init_benchmark_base_model_begin",
)
- if is_lr_scheduler_disabled(self.config.optimizer):
- self.optimizer.eval()
# we'll run validation on base model if it hasn't already.
self.validation.run_validations(validation_type="base_model", step=0)
self.validation.save_benchmark("base_model")
- if is_lr_scheduler_disabled(self.config.optimizer):
- self.optimizer.train()
self._send_webhook_raw(
structured_data={"message": "Base model benchmark completed"},
message_type="init_benchmark_base_model_completed",
@@ -2412,11 +2408,8 @@ def train(self):
# x-prediction requires that we now subtract the noise residual from the prediction to get the target sample.
if (
hasattr(self.noise_scheduler, "config")
- and hasattr(
- self.noise_scheduler.config, "prediction_type"
- )
- and self.noise_scheduler.config.prediction_type
- == "sample"
+ and hasattr(self.noise_scheduler.config, "prediction_type")
+ and self.noise_scheduler.config.prediction_type == "sample"
):
model_pred = model_pred - noise

@@ -2428,10 +2421,7 @@ def train(self):
loss = (
model_pred.float() - target.float()
) ** 2 # Shape: (batch_size, C, H, W)
- elif (
- self.config.snr_gamma is None
- or self.config.snr_gamma == 0
- ):
+ elif self.config.snr_gamma is None or self.config.snr_gamma == 0:
training_logger.debug("Calculating loss")
loss = self.config.snr_weight * F.mse_loss(
model_pred.float(), target.float(), reduction="none"
@@ -2448,8 +2438,7 @@ def train(self):
== "v_prediction"
or (
self.config.flow_matching
- and self.config.flow_matching_loss
- == "diffusion"
+ and self.config.flow_matching_loss == "diffusion"
)
):
snr_divisor = snr + 1
@@ -2461,8 +2450,7 @@ def train(self):
torch.stack(
[
snr,
- self.config.snr_gamma
- * torch.ones_like(timesteps),
+ self.config.snr_gamma * torch.ones_like(timesteps),
],
dim=1,
).min(dim=1)[0]
@@ -2478,9 +2466,7 @@ def train(self):
mse_loss_weights = mse_loss_weights.view(
-1, 1, 1, 1
) # Shape: (batch_size, 1, 1, 1)
- loss = (
- loss * mse_loss_weights
- ) # Shape: (batch_size, C, H, W)
+ loss = loss * mse_loss_weights # Shape: (batch_size, C, H, W)

# Mask the loss using any conditioning data
conditioning_type = batch.get("conditioning_type")
@@ -2514,8 +2500,7 @@ def train(self):
loss.repeat(self.config.train_batch_size)
).mean()
self.train_loss += (
- avg_loss.item()
- / self.config.gradient_accumulation_steps
+ avg_loss.item() / self.config.gradient_accumulation_steps
)
# Backpropagate
grad_norm = None

