Implement type-2 KFAC for MSELoss
runame committed Nov 7, 2023
1 parent 17f8c68 commit 756db7a
66 changes: 45 additions & 21 deletions curvlinops/kfac.py
@@ -18,7 +18,10 @@
 
 from einops import rearrange
 from numpy import ndarray
-from torch import Generator, Tensor, einsum, randn
+from torch import Generator, Tensor, einsum
+from torch import mean as torch_mean
+from torch import randn
+from torch import sum as torch_sum
 from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter
 from torch.utils.hooks import RemovableHandle
 
@@ -125,7 +128,7 @@ def __init__(
                 used which corresponds to the uncentered gradient covariance, or
                 the empirical Fisher. Defaults to ``'mc'``.
             mc_samples: The number of Monte-Carlo samples to use per data point.
-                Will be ignored when ``fisher_type`` is not ``'mc'``.
+                Has to be set to ``1`` when ``fisher_type != 'mc'``.
                 Defaults to ``1``.
 
         Raises:
@@ -138,6 +141,11 @@
             raise ValueError(
                 f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
             )
+        if fisher_type != "mc" and mc_samples != 1:
+            raise ValueError(
+                f"Invalid mc_samples: {mc_samples}. "
+                "Only mc_samples=1 is supported for fisher_type != 'mc'."
+            )
 
         self.param_ids = [p.data_ptr() for p in params]
         self.hooked_modules: List[str] = []
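
The check above makes the two options mutually exclusive: mc_samples is only meaningful for the Monte-Carlo approximation. A minimal usage sketch of the new behavior (class name and argument order follow the file this diff touches; the exact constructor signature is assumed):

    from torch import randn
    from torch.nn import Linear, MSELoss

    from curvlinops.kfac import KFACLinearOperator

    model, loss_func = Linear(3, 2), MSELoss()
    params = [p for p in model.parameters() if p.requires_grad]
    data = [(randn(8, 3), randn(8, 2))]

    # Fine: the exact (type-2) Fisher needs no Monte-Carlo samples.
    kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2")

    # Raises ValueError: mc_samples != 1 only makes sense for fisher_type="mc".
    kfac = KFACLinearOperator(
        model, loss_func, params, data, fisher_type="type-2", mc_samples=2
    )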
@@ -251,31 +259,46 @@ def _compute_kfac(self):
 
         for X, y in self._loop_over_data(desc="KFAC matrices"):
             output = self._model_func(X)
-
-            if self._fisher_type == "type-2":
-                raise NotImplementedError(
-                    "Using the exact expectation for computing the KFAC "
-                    "approximation of the Fisher is not yet supported."
-                )
-            elif self._fisher_type == "mc":
-                for mc in range(self._mc_samples):
-                    y_sampled = self.draw_label(output)
-                    loss = self._loss_func(output, y_sampled)
-                    loss.backward(retain_graph=mc != self._mc_samples - 1)
-            elif self._fisher_type == "empirical":
-                loss = self._loss_func(output, y)
-                loss.backward()
-            else:
-                raise ValueError(
-                    f"Invalid fisher_type: {self._fisher_type}. "
-                    + "Supported: 'type-2', 'mc', 'empirical'."
-                )
+            self._compute_loss_and_backward(output, y)
 
         # clean up
         self._model_func.zero_grad()
         for handle in hook_handles:
             handle.remove()
 
+    def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
+        """Compute the loss and the backward pass(es) required for KFAC."""
+        if self._fisher_type == "type-2":
+            reduction = self._loss_func.reduction
+            reduction_fn = {"sum": torch_sum, "mean": torch_mean}[reduction]
+            if isinstance(self._loss_func, MSELoss):
+                flat_logits = output.flatten(start_dim=1)
+                out_dims = flat_logits.shape[1]
+                scale = 1.0 / out_dims if reduction == "mean" else 1.0
+                for i in range(out_dims):
+                    # Mean or sum reduction of the i-th output over the batch.
+                    # Scale by sqrt(scale * 2.0): MSELoss omits the 1/2 factor
+                    # of the Gaussian negative log-likelihood, so the loss
+                    # Hessian w.r.t. each output carries a factor scale * 2.0.
+                    loss_i = sqrt(scale * 2.0) * reduction_fn(flat_logits[:, i])
+                    loss_i.backward(retain_graph=i < out_dims - 1)
+            elif isinstance(self._loss_func, CrossEntropyLoss):
+                raise NotImplementedError(
+                    "type-2 KFAC Fisher not yet implemented for CrossEntropyLoss."
+                )
+        elif self._fisher_type == "mc":
+            for mc in range(self._mc_samples):
+                y_sampled = self.draw_label(output)
+                loss = self._loss_func(output, y_sampled)
+                loss.backward(retain_graph=mc != self._mc_samples - 1)
+        elif self._fisher_type == "empirical":
+            loss = self._loss_func(output, y)
+            loss.backward()
+        else:
+            raise ValueError(
+                f"Invalid fisher_type: {self._fisher_type}. "
+                + "Supported: 'type-2', 'mc', 'empirical'."
+            )
 
     def draw_label(self, output: Tensor) -> Tensor:
         r"""Draw a sample from the model's predictive distribution.
@@ -361,6 +384,7 @@ def _hook_accumulate_gradient_covariance(
             )
 
         batch_size = g.shape[0]
+        # self._mc_samples will be 1 if fisher_type != "mc".
         correction = {
             "sum": 1.0 / self._mc_samples,
             "mean": batch_size**2 / (self._N_data * self._mc_samples),
