Skip to content

Commit

Permalink
Fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Nov 8, 2023
1 parent 6f7387a commit cefc921
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,12 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
if isinstance(self._loss_func, MSELoss):
flat_logits = output.flatten(start_dim=1)
out_dims = flat_logits.shape[1]
# Accounts for the reduction used in the loss function.
scale = 1.0 / out_dims if reduction == "mean" else 1.0
for i in range(out_dims):
# Mean or sum reduction over all loss terms.
# Multiply by sqrt(scale * 2.0) since the MSELoss does
# not include the scale / 2.0 factor.
# not include the 1 / 2 factor.
loss_i = sqrt(scale * 2.0) * reduction_fn(flat_logits[:, i])
loss_i.backward(retain_graph=i < out_dims - 1)
elif isinstance(self._loss_func, CrossEntropyLoss):
Expand Down

0 comments on commit cefc921

Please sign in to comment.