[FIX] Scaling of empirical/MC Fisher for output with more than two dimensions #109

Merged · 15 commits · May 5, 2024
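For context (an illustration, not part of the diff): with ``reduction="mean"``, PyTorch's ``MSELoss`` and ``BCEWithLogitsLoss`` normalize by the total number of output elements ``N * C``, while ``CrossEntropyLoss`` normalizes only by the number of loss terms ``N``. Curvature approximations built from per-sample gradients therefore need a loss-dependent correction once the output carries extra dimensions, which is what this PR adds. A minimal check of the underlying PyTorch behavior:

```python
import torch
from torch.nn import CrossEntropyLoss, MSELoss

N, C = 4, 3
f = torch.rand(N, C)
y_mse = torch.rand(N, C)
y_ce = torch.randint(C, (N,))

# MSELoss divides by the number of elements N * C ...
assert torch.allclose(
    MSELoss(reduction="mean")(f, y_mse),
    MSELoss(reduction="sum")(f, y_mse) / (N * C),
)
# ... while CrossEntropyLoss divides by the number of loss terms N.
assert torch.allclose(
    CrossEntropyLoss(reduction="mean")(f, y_ce),
    CrossEntropyLoss(reduction="sum")(f, y_ce) / N,
)
```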
21 changes: 19 additions & 2 deletions curvlinops/examples/functorch.py
@@ -5,9 +5,10 @@
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from einops import rearrange
from torch import Tensor, cat, einsum
from torch.func import functional_call, grad, hessian, jacrev, jvp, vmap
from torch.nn import Module
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss


def blocks_to_matrix(blocks: Dict[str, Dict[str, Tensor]]) -> Tensor:
@@ -239,20 +240,36 @@ def loss_n(

params_argnum = 2
batch_grad_fn = vmap(grad(loss_n, argnums=params_argnum))

# If the input has more than two dimensions, convert it to an equivalent 2d input;
# this assumes a model with only linear modules, i.e. the additional input
# dimensions are preserved until the output.
if isinstance(X, dict):
X[input_key] = rearrange(X[input_key], "batch ... d_in -> (batch ...) d_in")
else: # X is a tensor
X = rearrange(X, "batch ... d_in -> (batch ...) d_in")
N = X.shape[0] if batch_size_fn is None else batch_size_fn(X)

# If >2d label we convert to an equivalent 2d label
if isinstance(loss_func, CrossEntropyLoss):
y = rearrange(y, "batch ... -> (batch ...)")
else:
y = rearrange(y, "batch ... c -> (batch ...) c")

params_replicated_dict = {
name: p.unsqueeze(0).expand(N, *(p.dim() * [-1]))
for name, p in params_dict.items()
}

batch_grad = batch_grad_fn(X, y, params_replicated_dict)
batch_grad = cat([bg.flatten(start_dim=1) for bg in batch_grad.values()], dim=1)
assert batch_grad.shape == (N, sum(p.numel() for p in params))

if loss_func.reduction == "sum":
normalization = 1
elif loss_func.reduction == "mean":
normalization = N
if isinstance(loss_func, (MSELoss, BCEWithLogitsLoss)):
# the per-sample loss also averages over the C output entries, so each
# per-sample gradient carries an extra factor 1/C; scaling by sqrt(C)
# leaves exactly one factor 1/C in the gradient outer products
_, C = y.shape
batch_grad *= sqrt(C)
else:
raise ValueError("Cannot detect reduction method from loss function.")

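For illustration (made-up shapes, not part of the diff): the flattening convention used above folds all extra dimensions into the batch dimension with ``einops.rearrange``:

```python
import torch
from einops import rearrange

X = torch.rand(8, 5, 10)  # (batch, extra, d_in)
y = torch.rand(8, 5, 3)   # (batch, extra, c), e.g. MSELoss-style labels

# fold all leading dimensions except the feature dimension into the batch
X_flat = rearrange(X, "batch ... d_in -> (batch ...) d_in")
y_flat = rearrange(y, "batch ... c -> (batch ...) c")
assert X_flat.shape == (40, 10) and y_flat.shape == (40, 3)
```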
26 changes: 21 additions & 5 deletions curvlinops/fisher.py
@@ -157,8 +157,8 @@ def __init__(
entry of the iterates from ``data`` and return their batch size.

Raises:
NotImplementedError: If the loss function differs from ``MSELoss`` or
``CrossEntropyLoss``.
NotImplementedError: If the loss function differs from ``MSELoss``,
``BCEWithLogitsLoss``, or ``CrossEntropyLoss``.
"""
if not isinstance(loss_func, self.supported_losses):
raise NotImplementedError(
@@ -215,15 +215,31 @@ def _matmat_batch(
# compute ∂ℓₙ(yₙₘ)/∂fₙ where fₙ is the prediction for datum n and
# yₙₘ is the m-th sampled label for datum n
output = self._model_func(X)
# If >2d output we convert to an equivalent 2d output
if isinstance(self._loss_func, CrossEntropyLoss):
output = rearrange(output, "batch c ... -> (batch ...) c")
y = rearrange(y, "batch ... -> (batch ...)")
else:
output = rearrange(output, "batch ... c -> (batch ...) c")
y = rearrange(y, "batch ... c -> (batch ...) c")

grad_output = self.sample_grad_output(output, self._mc_samples, y)

# Adjust the scale depending on the loss reduction used
num_loss_terms, C = output.shape
reduction_factor = {
"mean": (
num_loss_terms
if isinstance(self._loss_func, CrossEntropyLoss)
else num_loss_terms * C
),
"sum": 1.0,
}[self._loss_func.reduction]

# Compute the pseudo-loss L' := 0.5 / (M * c) ∑ₙ ∑ₘ fₙᵀ (gₙₘ gₙₘᵀ) fₙ where
# gₙₘ = ∂ℓₙ(yₙₘ)/∂fₙ (detached) and M is the number of MC samples.
# The GGN of L' linearized at fₙ is the MC Fisher.
# We can thus multiply with it by computing the GGN-vector products of L'.
reduction_factor = {"mean": self._batch_size_fn(X), "sum": 1.0}[
self._loss_func.reduction
]
loss = (
0.5
/ reduction_factor
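For illustration, a standalone sketch of the reduction-factor logic added above (the helper function and its name are hypothetical; the diff computes this inline):

```python
from torch.nn import CrossEntropyLoss


def reduction_factor(loss_func, output):
    """Return the factor by which the (2d-flattened) loss is normalized."""
    num_loss_terms, C = output.shape
    if loss_func.reduction == "sum":
        return 1.0
    if loss_func.reduction == "mean":
        # CrossEntropyLoss averages only over loss terms; MSELoss and
        # BCEWithLogitsLoss additionally average over the C output entries.
        return (
            num_loss_terms
            if isinstance(loss_func, CrossEntropyLoss)
            else num_loss_terms * C
        )
    raise ValueError(f"Unknown reduction: {loss_func.reduction}.")
```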
92 changes: 87 additions & 5 deletions curvlinops/gradient_moments.py
@@ -3,12 +3,13 @@
from __future__ import annotations

from collections.abc import MutableMapping
from typing import List, Tuple, Union
from typing import Callable, Iterable, List, Optional, Tuple, Union

from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
from einops import einsum
from einops import einsum, rearrange
from torch import Tensor, zeros_like
from torch.autograd import grad
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, Parameter

from curvlinops._base import _LinearOperator

Expand Down Expand Up @@ -45,6 +46,72 @@ class EFLinearOperator(_LinearOperator):
inefficient for-loop.
"""

supported_losses = (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)

def __init__(
self,
model_func: Callable[[Tensor], Tensor],
loss_func: Union[Callable[[Tensor, Tensor], Tensor], None],
params: List[Parameter],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
progressbar: bool = False,
check_deterministic: bool = True,
num_data: Optional[int] = None,
batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
):
"""Linear operator for the uncentered gradient covariance/empirical Fisher (EF).

Note:
f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch
input X to predictions p. ℓ(p, y) maps the prediction to a loss, using the
mini-batch labels y.

Args:
model_func: A function that maps the mini-batch input X to predictions.
Could be a PyTorch module representing a neural network.
loss_func: Loss function criterion. Maps predictions and mini-batch labels
to a scalar value.
params: List of differentiable parameters used by the prediction function.
data: Source from which mini-batches can be drawn, for instance a list of
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``. Note that ``X``
could be a ``dict`` or ``UserDict``; this is useful for custom models.
In this case, you must (i) specify the ``batch_size_fn`` argument, and
(ii) take care of preprocessing like ``X.to(device)`` inside of your
``model.forward()`` function. Because mini-batches are traversed
sequentially, they must be presented in the same deterministic
order (no shuffling!).
progressbar: Show a progressbar during matrix-multiplication.
Default: ``False``.
check_deterministic: Probe that model and data are deterministic, i.e.
that the data does not use `drop_last` or data augmentation. Also, the
model's forward pass could depend on the order in which mini-batches
are presented (BatchNorm, Dropout). Default: ``True``. This is a
safeguard, only turn it off if you know what you are doing.
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this
needs to be specified. The intended behavior is to consume the first
entry of the iterates from ``data`` and return their batch size.

Raises:
NotImplementedError: If the loss function differs from ``MSELoss``,
``BCEWithLogitsLoss``, or ``CrossEntropyLoss``.
"""
if not isinstance(loss_func, self.supported_losses):
raise NotImplementedError(
f"Loss must be one of {self.supported_losses}. Got: {loss_func}."
)
super().__init__(
model_func,
loss_func,
params,
data,
progressbar=progressbar,
check_deterministic=check_deterministic,
num_data=num_data,
batch_size_fn=batch_size_fn,
)

def _matmat_batch(
self, X: Union[Tensor, MutableMapping], y: Tensor, M_list: List[Tensor]
) -> Tuple[Tensor, ...]:
@@ -63,9 +130,24 @@ def _matmat_batch(
leading dimension of matrix columns.
"""
output = self._model_func(X)
reduction_factor = {"mean": self._batch_size_fn(X), "sum": 1.0}[
self._loss_func.reduction
]
# If >2d output we convert to an equivalent 2d output
if isinstance(self._loss_func, CrossEntropyLoss):
output = rearrange(output, "batch c ... -> (batch ...) c")
y = rearrange(y, "batch ... -> (batch ...)")
else:
output = rearrange(output, "batch ... c -> (batch ...) c")
y = rearrange(y, "batch ... c -> (batch ...) c")

# Adjust the scale depending on the loss reduction used
num_loss_terms, C = output.shape
reduction_factor = {
"mean": (
num_loss_terms
if isinstance(self._loss_func, CrossEntropyLoss)
else num_loss_terms * C
),
"sum": 1.0,
}[self._loss_func.reduction]

# compute ∂ℓₙ/∂fₙ without reduction factor of L
(grad_output,) = grad(self._loss_func(output, y), output)
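For illustration, a hypothetical usage of the extended ``EFLinearOperator`` with a >2d output (toy model and shapes are made up; assumes the operator is exported at the package top level):

```python
import torch
from torch.nn import Linear, MSELoss

from curvlinops import EFLinearOperator

N, S, D_in, C = 4, 7, 5, 3  # batch, sequence, input, and output dimensions
model = Linear(D_in, C)     # maps (N, S, D_in) -> (N, S, C)
loss_func = MSELoss(reduction="mean")
params = [p for p in model.parameters() if p.requires_grad]
data = [(torch.rand(N, S, D_in), torch.rand(N, S, C))]

EF = EFLinearOperator(model, loss_func, params, data)
D = sum(p.numel() for p in params)
v = torch.rand(D).numpy()  # SciPy-style linear operator: NumPy in, NumPy out
assert (EF @ v).shape == (D,)
```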
42 changes: 28 additions & 14 deletions curvlinops/kfac.py
@@ -523,6 +523,30 @@ def _compute_kfac(self):
for handle in hook_handles:
handle.remove()

def _maybe_adjust_loss_scale(self, loss: Tensor, output: Tensor) -> Tensor:
"""Adjust the scale of the loss tensor if necessary.

``BCEWithLogitsLoss`` and ``MSELoss`` with ``reduction='mean'`` also average
over the output dimension in addition to the batch dimension. The loss is
rescaled by ``sqrt(C)`` to correct for this.

Args:
loss: The loss tensor to adjust.
output: The model's output.

Returns:
The scaled loss tensor.
"""
if (
isinstance(self._loss_func, (BCEWithLogitsLoss, MSELoss))
and self._loss_func.reduction == "mean"
):
# ``BCEWithLogitsLoss`` and ``MSELoss`` also average over non-batch
# dimensions. We have to scale the loss to incorporate this scaling.
_, C = output.shape
loss *= sqrt(C)
return loss

def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
r"""Compute the loss and the backward pass(es) required for KFAC.

@@ -554,9 +578,9 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
)

# Fix scaling caused by the number of loss terms
batch_size = output.shape[0]
num_loss_terms = output.shape[0]
reduction = self._loss_func.reduction
scale = {"sum": 1.0, "mean": 1.0 / batch_size}[reduction]
scale = {"sum": 1.0, "mean": 1.0 / num_loss_terms}[reduction]
hessian_sqrts.mul_(scale)

# For each column `c` of the matrix square root we need to backpropagate,
@@ -574,22 +598,12 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
for mc in range(self._mc_samples):
y_sampled = self.draw_label(output)
loss = self._loss_func(output, y_sampled)

if (
isinstance(self._loss_func, (BCEWithLogitsLoss, MSELoss))
and self._loss_func.reduction == "mean"
):
# ``BCEWithLogitsLoss`` and ``MSELoss`` also average over non-batch
# dimensions. We have to scale the loss to incorporate this scaling
# as we cannot generally achieve it by incorporating it into the
# drawn sample.
_, C = output.shape
loss *= sqrt(C)

loss = self._maybe_adjust_loss_scale(loss, output)
grad(loss, self._params, retain_graph=mc != self._mc_samples - 1)

elif self._fisher_type == "empirical":
loss = self._loss_func(output, y)
loss = self._maybe_adjust_loss_scale(loss, output)
grad(loss, self._params)

elif self._fisher_type == "forward-only":
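A note on the ``sqrt(C)`` trick used in ``_maybe_adjust_loss_scale`` (illustration, not part of the diff): scaling the loss by ``sqrt(C)`` scales its gradients by ``sqrt(C)`` and therefore every gradient outer product, the building block of the Fisher/EF, by ``C``. This undoes the extra ``1/C`` that ``MSELoss`` and ``BCEWithLogitsLoss`` introduce under ``reduction="mean"``, which cannot be folded into the drawn labels. A minimal check:

```python
import torch
from torch.nn import MSELoss

N, C = 4, 3
f = torch.rand(N, C, requires_grad=True)
y = torch.rand(N, C)
loss = MSELoss(reduction="mean")(f, y)

(g,) = torch.autograd.grad(loss, f, retain_graph=True)
(g_scaled,) = torch.autograd.grad(loss * C**0.5, f)

# per-sample gradient outer products are scaled by exactly C
outer = torch.einsum("ni,nj->nij", g, g)
outer_scaled = torch.einsum("ni,nj->nij", g_scaled, g_scaled)
torch.testing.assert_close(outer_scaled, C * outer)
```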