
Commit

(#990) do not call optimiser eval/train func when benchmarking since it is early enough not to be needed
bghira committed Oct 20, 2024
1 parent a9f3eb8 commit ddc7789
Showing 1 changed file with 7 additions and 22 deletions.
29 changes: 7 additions & 22 deletions helpers/training/trainer.py
@@ -1395,13 +1395,9 @@ def init_benchmark_base_model(self):
             structured_data={"message": "Base model benchmark begins"},
             message_type="init_benchmark_base_model_begin",
         )
-        if is_lr_scheduler_disabled(self.config.optimizer):
-            self.optimizer.eval()
         # we'll run validation on base model if it hasn't already.
         self.validation.run_validations(validation_type="base_model", step=0)
         self.validation.save_benchmark("base_model")
-        if is_lr_scheduler_disabled(self.config.optimizer):
-            self.optimizer.train()
         self._send_webhook_raw(
             structured_data={"message": "Base model benchmark completed"},
             message_type="init_benchmark_base_model_completed",
@@ -2412,11 +2408,8 @@ def train(self):
                     # x-prediction requires that we now subtract the noise residual from the prediction to get the target sample.
                     if (
                         hasattr(self.noise_scheduler, "config")
-                        and hasattr(
-                            self.noise_scheduler.config, "prediction_type"
-                        )
-                        and self.noise_scheduler.config.prediction_type
-                        == "sample"
+                        and hasattr(self.noise_scheduler.config, "prediction_type")
+                        and self.noise_scheduler.config.prediction_type == "sample"
                     ):
                         model_pred = model_pred - noise
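For context on the `prediction_type == "sample"` branch: the regression target depends on the scheduler's parameterization, and with sample prediction this trainer adjusts `model_pred` rather than the target. A sketch of the usual diffusers-style target selection (names such as `latents` are assumptions, not lines from this file):

    # Standard mapping from scheduler prediction_type to regression target.
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise  # model predicts the added noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    elif noise_scheduler.config.prediction_type == "sample":
        target = latents  # model predicts the clean sample directly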

@@ -2428,10 +2421,7 @@ def train(self):
                         loss = (
                             model_pred.float() - target.float()
                         ) ** 2  # Shape: (batch_size, C, H, W)
-                    elif (
-                        self.config.snr_gamma is None
-                        or self.config.snr_gamma == 0
-                    ):
+                    elif self.config.snr_gamma is None or self.config.snr_gamma == 0:
                         training_logger.debug("Calculating loss")
                         loss = self.config.snr_weight * F.mse_loss(
                             model_pred.float(), target.float(), reduction="none"
@@ -2448,8 +2438,7 @@ def train(self):
                             == "v_prediction"
                             or (
                                 self.config.flow_matching
-                                and self.config.flow_matching_loss
-                                == "diffusion"
+                                and self.config.flow_matching_loss == "diffusion"
                             )
                         ):
                             snr_divisor = snr + 1
@@ -2461,8 +2450,7 @@ def train(self):
                             torch.stack(
                                 [
                                     snr,
-                                    self.config.snr_gamma
-                                    * torch.ones_like(timesteps),
+                                    self.config.snr_gamma * torch.ones_like(timesteps),
                                 ],
                                 dim=1,
                             ).min(dim=1)[0]
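These lines implement min-SNR loss weighting: each timestep's SNR is clamped at `snr_gamma`, then divided by the SNR (or by `snr + 1` for v-prediction and flow-matching diffusion loss, the `snr_divisor` above). A standalone sketch of the weight computation; `min_snr_weights` is a hypothetical helper, not part of the trainer:

    import torch

    def min_snr_weights(snr: torch.Tensor, gamma: float, v_prediction: bool) -> torch.Tensor:
        # Clamp SNR at gamma, matching the torch.stack(...).min(dim=1)[0]
        # expression in the diff.
        clamped = torch.minimum(snr, torch.full_like(snr, gamma))
        divisor = snr + 1 if v_prediction else snr
        return clamped / divisor

    snr = torch.tensor([0.5, 5.0, 50.0])
    print(min_snr_weights(snr, gamma=5.0, v_prediction=False))
    # tensor([1.0000, 1.0000, 0.1000])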
@@ -2478,9 +2466,7 @@ def train(self):
                         mse_loss_weights = mse_loss_weights.view(
                             -1, 1, 1, 1
                         )  # Shape: (batch_size, 1, 1, 1)
-                        loss = (
-                            loss * mse_loss_weights
-                        )  # Shape: (batch_size, C, H, W)
+                        loss = loss * mse_loss_weights  # Shape: (batch_size, C, H, W)

                         # Mask the loss using any conditioning data
                         conditioning_type = batch.get("conditioning_type")
@@ -2514,8 +2500,7 @@ def train(self):
                             loss.repeat(self.config.train_batch_size)
                         ).mean()
                         self.train_loss += (
-                            avg_loss.item()
-                            / self.config.gradient_accumulation_steps
+                            avg_loss.item() / self.config.gradient_accumulation_steps
                         )
                         # Backpropagate
                         grad_norm = None
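The `train_loss` accumulation follows the common accelerate pattern: gather the loss from every process, average it, and divide by the gradient-accumulation step count so that the value logged per optimizer step is the mean over the full effective batch. A sketch under those assumptions (`accelerator`, `loss`, and `config` are taken from context, not defined here):

    # Each micro-batch contributes loss / G; after G accumulation steps,
    # train_loss holds the mean loss for the whole effective batch.
    avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
    train_loss += avg_loss.item() / config.gradient_accumulation_steps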
