[FIX] Scaling of MC Fisher for MSELoss and BCEWithLogitsLoss with mean reduction #112

Merged (1 commit) on May 11, 2024.
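Background for the fix (not part of the diff): with ``reduction='mean'``, PyTorch's ``MSELoss`` and ``BCEWithLogitsLoss`` average over all ``N * C`` output elements, not just the ``N`` batch entries. A quick check of this behavior:

```python
import torch
from torch.nn import BCEWithLogitsLoss, MSELoss

N, C = 8, 3
f = torch.randn(N, C)
y_reg = torch.randn(N, C)
y_cls = torch.randint(0, 2, (N, C)).float()

# Both losses divide by all N * C elements, not only by the N loss terms.
assert torch.allclose(
    MSELoss(reduction="mean")(f, y_reg),
    MSELoss(reduction="sum")(f, y_reg) / (N * C),
)
assert torch.allclose(
    BCEWithLogitsLoss(reduction="mean")(f, y_cls),
    BCEWithLogitsLoss(reduction="sum")(f, y_cls) / (N * C),
)
```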
curvlinops/fisher.py: 21 changes (9 additions, 12 deletions)
```diff
@@ -225,7 +225,7 @@ def _matmat_batch(

         grad_output = self.sample_grad_output(output, self._mc_samples, y)

-        # Adjust the scale depending on the loss reduction used
+        # Adjust the scale depending on the loss function and reduction used
         num_loss_terms, C = output.shape
         reduction_factor = {
             "mean": (
```
```diff
@@ -266,7 +266,12 @@ def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Tensor:

         For a single data point, the would-be gradient's outer product equals the
         Hessian ``∇²_f log p(·|f)`` in expectation.

-        Currently only supports ``MSELoss`` and ``CrossEntropyLoss``.
+        Currently only supports ``MSELoss``, ``CrossEntropyLoss``, and
+        ``BCEWithLogitsLoss``.
+
+        The returned gradient does not account for the scaling of the loss function by
+        the output dimension ``C`` that ``MSELoss`` and ``BCEWithLogitsLoss`` apply when
+        ``reduction='mean'``.

         Args:
             output: model prediction ``f`` for multiple data with batch axis as
```
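To make the omitted scaling concrete, an illustrative autograd check (not from the PR): the mean-reduced gradient is the sum-reduced gradient divided by ``N * C``, which is presumably the factor the reduction factor in ``_matmat_batch`` now absorbs instead of ``sample_grad_output``:

```python
import torch
from torch.nn import MSELoss

N, C = 8, 3
f = torch.randn(N, C, requires_grad=True)
y = torch.randn(N, C)

# Gradient of the mean-reduced loss equals the sum-reduced gradient
# divided by all N * C elements.
(g_mean,) = torch.autograd.grad(MSELoss(reduction="mean")(f, y), f)
(g_sum,) = torch.autograd.grad(MSELoss(reduction="sum")(f, y), f)
assert torch.allclose(g_mean, g_sum / (N * C))
```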
```diff
@@ -289,10 +294,7 @@ def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Tensor:
         C = output.shape[1]

         if isinstance(self._loss_func, MSELoss):
-            std = as_tensor(
-                {"mean": sqrt(0.5 / C), "sum": sqrt(0.5)}[self._loss_func.reduction],
-                device=output.device,
-            )
+            std = as_tensor(sqrt(0.5), device=output.device)
             mean = zeros(
                 num_samples, *output.shape, device=output.device, dtype=output.dtype
             )
```
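In isolation, a minimal sketch of the fixed ``MSELoss`` sampling (shapes and the trailing ``normal`` call are assumed from the visible context): the standard deviation is ``sqrt(0.5)`` for both reductions, with the ``1/C`` mean-reduction scaling no longer baked into the samples:

```python
from math import sqrt

import torch

num_samples, N, C = 4, 8, 3
output = torch.randn(N, C)

# Post-fix: the std depends on neither the reduction nor C.
std = sqrt(0.5)  # previously sqrt(0.5 / C) for reduction="mean"
mean = torch.zeros(num_samples, *output.shape)
grad_output = torch.normal(mean, std)  # shape (num_samples, N, C)
```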
```diff
@@ -320,12 +322,7 @@ def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Tensor:
             # repeat ``num_samples`` times along a new leading axis
             prob = prob.unsqueeze(0).expand(num_samples, -1, -1)
             sample = prob.bernoulli(generator=self._generator)
-
-            # With ``reduction="mean"``, BCEWithLogitsLoss averages over all
-            # dimensions, like ``MSELoss``. We need to incorporate this scaling
-            # into the backpropagated gradient
-            scale = {"sum": 1.0, "mean": sqrt(1.0 / C)}[self._loss_func.reduction]
-            return (prob - sample) * scale
+            return prob - sample

         else:
             raise NotImplementedError(f"Supported losses: {self.supported_losses}")
```
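For ``BCEWithLogitsLoss``, ``prob - sample`` is the exact gradient of the sum-reduced loss with respect to the logits when the label equals the Bernoulli sample, so no extra ``sqrt(1/C)`` factor belongs here. An illustrative autograd sanity check:

```python
import torch
from torch.nn import BCEWithLogitsLoss

f = torch.randn(5, 3, requires_grad=True)
y = torch.randint(0, 2, (5, 3)).float()  # stand-in for a Bernoulli sample

loss = BCEWithLogitsLoss(reduction="sum")(f, y)
(grad,) = torch.autograd.grad(loss, f)

# d/df [-y log σ(f) - (1 - y) log(1 - σ(f))] = σ(f) - y
assert torch.allclose(grad, torch.sigmoid(f) - y)
```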