Fix scaling of MC Fisher for MSELoss and BCEWithLogitsLoss with mean reduction (#112)
runame authored May 11, 2024
1 parent fa502d2 commit c8e9204
Showing 1 changed file with 9 additions and 12 deletions.
curvlinops/fisher.py: 9 additions & 12 deletions
@@ -225,7 +225,7 @@ def _matmat_batch(

         grad_output = self.sample_grad_output(output, self._mc_samples, y)

-        # Adjust the scale depending on the loss reduction used
+        # Adjust the scale depending on the loss function and reduction used
         num_loss_terms, C = output.shape
         reduction_factor = {
             "mean": (
@@ -266,7 +266,12 @@ def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Tensor:
         For a single data point, the would-be gradient's outer product equals the
         Hessian ``∇²_f log p(·|f)`` in expectation.

-        Currently only supports ``MSELoss`` and ``CrossEntropyLoss``.
+        Currently only supports ``MSELoss``, ``CrossEntropyLoss``, and
+        ``BCEWithLogitsLoss``.
+
+        The returned gradient does not account for the scaling of the loss function by
+        the output dimension ``C`` that ``MSELoss`` and ``BCEWithLogitsLoss`` apply when
+        ``reduction='mean'``.

         Args:
             output: model prediction ``f`` for multiple data with batch axis as
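
The new docstring note can be made concrete for ``MSELoss`` with ``reduction='sum'`` on a single output coordinate (a worked sketch of the claim, not text from the source):

```latex
% Squared error read as a Gaussian negative log-likelihood with variance 1/2:
\ell(f, y) = (f - y)^2 = -\log \mathcal{N}\!\left(y \mid f, \tfrac{1}{2}\right) + \mathrm{const},
\qquad \nabla_f \ell = 2 (f - y).
% For \tilde{y} \sim \mathcal{N}(f, \tfrac{1}{2}), the would-be gradient
% g = 2 (f - \tilde{y}) matches the Hessian in expectation:
\mathbb{E}\left[ g g^{\top} \right] = 4 \operatorname{Var}(\tilde{y}) = 2 = \nabla_f^2 \ell.
```

With ``reduction='mean'`` the loss, and hence the Hessian, carries an extra ``1/(num_loss_terms * C)``, which is now applied once via ``reduction_factor`` instead of inside the sampled gradient.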
@@ -289,10 +294,7 @@ def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Tensor:
         C = output.shape[1]

         if isinstance(self._loss_func, MSELoss):
-            std = as_tensor(
-                {"mean": sqrt(0.5 / C), "sum": sqrt(0.5)}[self._loss_func.reduction],
-                device=output.device,
-            )
+            std = as_tensor(sqrt(0.5), device=output.device)
             mean = zeros(
                 num_samples, *output.shape, device=output.device, dtype=output.dtype
             )
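
A Monte Carlo sanity check of the new convention (an illustrative standalone snippet, not repository code): with ``std = sqrt(0.5)``, the sampled would-be gradients reproduce the per-datum Hessian ``2·I`` of sum-reduced MSE regardless of ``C``:

```python
import torch

torch.manual_seed(0)
num_samples, C = 100_000, 3

# would-be targets y~ = f + eps with eps ~ N(0, 0.5); the gradient of the
# sum-reduced squared error (f - y~)^2 w.r.t. f is then g = -2 * eps
eps = torch.randn(num_samples, C) * (0.5 ** 0.5)
g = -2 * eps

# averaged outer products approximate E[g g^T] = 2 * I, the per-datum
# Hessian of sum-reduced MSE, independent of the output dimension C
fisher = torch.einsum("ni,nj->ij", g, g) / num_samples
assert torch.allclose(fisher, 2 * torch.eye(C), atol=0.05)
```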
@@ -320,12 +322,7 @@ def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Tensor:
             # repeat ``num_samples`` times along a new leading axis
             prob = prob.unsqueeze(0).expand(num_samples, -1, -1)
             sample = prob.bernoulli(generator=self._generator)
-
-            # With ``reduction="mean"``, BCEWithLogitsLoss averages over all
-            # dimensions, like ``MSELoss``. We need to incorporate this scaling
-            # into the backpropagated gradient
-            scale = {"sum": 1.0, "mean": sqrt(1.0 / C)}[self._loss_func.reduction]
-            return (prob - sample) * scale
+            return prob - sample

         else:
             raise NotImplementedError(f"Supported losses: {self.supported_losses}")
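
For ``BCEWithLogitsLoss``, the unscaled ``prob - sample`` is already correct under the sum convention: the gradient of the sum-reduced loss w.r.t. a logit ``f`` is ``sigmoid(f) - y``, its Hessian is ``sigmoid(f) * (1 - sigmoid(f))``, and that Hessian equals ``Var(y~)`` for ``y~ ~ Bernoulli(sigmoid(f))``. A quick check (illustrative snippet, not repository code); the ``reduction='mean'`` factor is again applied via ``reduction_factor``:

```python
import torch

torch.manual_seed(0)
num_samples = 100_000
f = torch.tensor(0.3)  # a single logit
prob = torch.sigmoid(f)

# sampled would-be gradients sigmoid(f) - y~ with y~ ~ Bernoulli(sigmoid(f))
sample = prob.expand(num_samples).bernoulli()
mc_fisher = ((prob - sample) ** 2).mean()

# Hessian of the sum-reduced loss w.r.t. the logit: sigmoid(f) * (1 - sigmoid(f))
assert torch.allclose(mc_fisher, prob * (1 - prob), atol=0.01)
```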

