[ADD] Support for KFAC with type-2 Fisher #56

Merged (18 commits, Nov 9, 2023)
Changes from 11 commits
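
For orientation, a minimal sketch of how the new fisher_type="type-2" option is meant to be used, based on the constructor and the tests in this diff. The toy model, loss function, and data below are illustrative assumptions and not part of the change:

    from numpy import eye
    from torch import manual_seed, rand
    from torch.nn import Linear, MSELoss, ReLU, Sequential

    from curvlinops.kfac import KFACLinearOperator

    manual_seed(0)

    # Hypothetical toy setup: a two-layer regression model and one synthetic batch.
    model = Sequential(Linear(4, 3), ReLU(), Linear(3, 2))
    loss_func = MSELoss(reduction="mean")
    params = [p for p in model.parameters() if p.requires_grad]
    data = [(rand(8, 4), rand(8, 2))]

    # Request the exact (type-2) Fisher instead of the default MC approximation.
    kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2")

    # Materialize the KFAC approximation as a dense matrix, as the tests do.
    kfac_mat = kfac @ eye(kfac.shape[1])
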
92 changes: 64 additions & 28 deletions curvlinops/kfac.py
@@ -18,8 +18,12 @@

from einops import rearrange
from numpy import ndarray
from torch import Generator, Tensor, cat, einsum, randn
from torch import Generator, Tensor, cat, einsum
from torch import mean as torch_mean
from torch import no_grad, randn
from torch import sum as torch_sum
from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter
from torch.nn.functional import log_softmax, softmax
from torch.utils.hooks import RemovableHandle

from curvlinops._base import _LinearOperator
@@ -124,7 +128,7 @@ def __init__(
used which corresponds to the uncentered gradient covariance, or
the empirical Fisher. Defaults to ``'mc'``.
mc_samples: The number of Monte-Carlo samples to use per data point.
Will be ignored when ``fisher_type`` is not ``'mc'``.
Has to be set to ``1`` when ``fisher_type != 'mc'``.
Defaults to ``1``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
@@ -137,6 +141,11 @@
raise ValueError(
f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
)
if fisher_type != "mc" and mc_samples != 1:
raise ValueError(
f"Invalid mc_samples: {mc_samples}. "
"Only mc_samples=1 is supported for fisher_type != 'mc'."
)

self.param_ids = [p.data_ptr() for p in params]
# mapping from tuples of parameter data pointers in a module to its name
@@ -230,13 +239,7 @@ def _adjoint(self) -> KFACLinearOperator:
return self

def _compute_kfac(self):
"""Compute and cache KFAC's Kronecker factors for future ``matvec``s.

Raises:
NotImplementedError: If ``fisher_type == 'type-2'``.
ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
``'empirical'``.
"""
"""Compute and cache KFAC's Kronecker factors for future ``matvec``s."""
# install forward and backward hooks
hook_handles: List[RemovableHandle] = []

@@ -265,31 +268,63 @@ def _compute_kfac(self):

for X, y in self._loop_over_data(desc="KFAC matrices"):
output = self._model_func(X)

if self._fisher_type == "type-2":
raise NotImplementedError(
"Using the exact expectation for computing the KFAC "
"approximation of the Fisher is not yet supported."
)
elif self._fisher_type == "mc":
for mc in range(self._mc_samples):
y_sampled = self.draw_label(output)
loss = self._loss_func(output, y_sampled)
loss.backward(retain_graph=mc != self._mc_samples - 1)
elif self._fisher_type == "empirical":
loss = self._loss_func(output, y)
loss.backward()
else:
raise ValueError(
f"Invalid fisher_type: {self._fisher_type}. "
+ "Supported: 'type-2', 'mc', 'empirical'."
)
self._compute_loss_and_backward(output, y)

# clean up
self._model_func.zero_grad()
for handle in hook_handles:
handle.remove()

def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
r"""Compute the loss and the backward pass(es) required for KFAC.

Args:
output: The model's prediction
:math:`\{f_\mathbf{\theta}(\mathbf{x}_n)\}_{n=1}^N`.
y: The labels :math:`\{\mathbf{y}_n\}_{n=1}^N`.

Raises:
ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
``'empirical'``.
"""
if self._fisher_type == "type-2":
reduction = self._loss_func.reduction
reduction_fn = {"sum": torch_sum, "mean": torch_mean}[reduction]
if isinstance(self._loss_func, MSELoss):
flat_logits = output.flatten(start_dim=1)
out_dims = flat_logits.shape[1]
# Accounts for the reduction used in the loss function.
scale = 1.0 / out_dims if reduction == "mean" else 1.0
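# Per sample, the Hessian of the scaled squared error w.r.t. the model output
# is (scale * 2.0) * I, so backpropagating sqrt(scale * 2.0) * f_i once per
# output dimension i yields output gradients whose accumulated outer products
# give the exact Fisher; batch-size factors are handled by the correction in
# the gradient-covariance hook.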
for i in range(out_dims):
# Mean or sum reduction over all loss terms.
# Multiply by sqrt(scale * 2.0) since the MSELoss does
# not include the 1 / 2 factor.
loss_i = sqrt(scale * 2.0) * reduction_fn(flat_logits[:, i])
loss_i.backward(retain_graph=i < out_dims - 1)
elif isinstance(self._loss_func, CrossEntropyLoss):
flat_logits = output.flatten(end_dim=-2)
log_probs = log_softmax(flat_logits, dim=-1)
with no_grad():
sqrt_probs = softmax(flat_logits, dim=-1).sqrt()
num_classes = log_probs.shape[1]
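# The Fisher of the categorical likelihood decomposes as
# sum_c p_c * grad(-log p_c) grad(-log p_c)^T = diag(p) - p p^T w.r.t. the
# logits. Backpropagating -sqrt(p_c) * log p_c for each class c, with sqrt(p_c)
# detached above, hence yields output gradients whose accumulated outer
# products equal the exact Fisher.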
for c in range(num_classes):
# Mean or sum reduction over all loss terms.
loss_c = reduction_fn(-log_probs[:, c] * sqrt_probs[:, c])
loss_c.backward(retain_graph=c < num_classes - 1)
elif self._fisher_type == "mc":
for mc in range(self._mc_samples):
y_sampled = self.draw_label(output)
loss = self._loss_func(output, y_sampled)
loss.backward(retain_graph=mc != self._mc_samples - 1)
elif self._fisher_type == "empirical":
loss = self._loss_func(output, y)
loss.backward()
else:
raise ValueError(
f"Invalid fisher_type: {self._fisher_type}. "
+ "Supported: 'type-2', 'mc', 'empirical'."
)

def draw_label(self, output: Tensor) -> Tensor:
r"""Draw a sample from the model's predictive distribution.

@@ -375,6 +410,7 @@ def _hook_accumulate_gradient_covariance(
)

batch_size = g.shape[0]
# self._mc_samples will be 1 if fisher_type != "mc"
correction = {
"sum": 1.0 / self._mc_samples,
"mean": batch_size**2 / (self._N_data * self._mc_samples),
14 changes: 0 additions & 14 deletions test/conftest.py
@@ -86,17 +86,3 @@ def kfac_expand_exact_one_datum_case(
"""
case = request.param
yield initialize_case(case)


@fixture(params=KFAC_EXPAND_EXACT_ONE_DATUM_CASES)
def kfac_ef_exact_one_datum_case(
request,
) -> Tuple[Module, MSELoss, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]:
"""Prepare a test case with one datum for which KFAC with empirical gradients equals the EF.

Yields:
A neural network, the mean-squared error function, a list of parameters, and
a data set.
"""
case = request.param
yield initialize_case(case)
63 changes: 52 additions & 11 deletions test/test_kfac.py
@@ -7,7 +7,7 @@
from pytest import mark
from scipy.linalg import block_diag
from torch import Tensor, randperm
from torch.nn import Module, MSELoss, Parameter
from torch.nn import CrossEntropyLoss, Module, MSELoss, Parameter

from curvlinops.examples.utils import report_nonclose
from curvlinops.gradient_moments import EFLinearOperator
@@ -58,30 +58,71 @@ def test_kfac(
data,
separate_weight_and_bias=separate_weight_and_bias,
)

kfac = KFACLinearOperator(
model,
loss_func,
params,
data,
mc_samples=2_000,
fisher_type="type-2",
separate_weight_and_bias=separate_weight_and_bias,
)
kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)
report_nonclose(ggn, kfac_mat)

# Check that input covariances were not computed
if exclude == "weight":
assert len(kfac._input_covariances) == 0


@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
def test_kfac_mc(
kfac_expand_exact_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
],
shuffle: bool,
):
"""Test the KFAC implementation using MC samples against the exact GGN.

Args:
kfac_expand_exact_case: A fixture that returns a model, loss function, list of
parameters, and data.
shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
"""
model, loss_func, params, data = kfac_expand_exact_case

if shuffle:
permutation = randperm(len(params))
params = [params[i] for i in permutation]

ggn = ggn_block_diagonal(model, loss_func, params, data)
kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000)

kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)


def test_kfac_one_datum(
kfac_expand_exact_one_datum_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case

ggn = ggn_block_diagonal(model, loss_func, params, data)
kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2")
kfac_mat = kfac @ eye(kfac.shape[1])

report_nonclose(ggn, kfac_mat)


def test_kfac_mc_one_datum(
kfac_expand_exact_one_datum_case: Tuple[
Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case
@@ -97,11 +138,11 @@ def test_kfac_one_datum(


def test_kfac_ef_one_datum(
kfac_ef_exact_one_datum_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
kfac_expand_exact_one_datum_case: Tuple[
Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_ef_exact_one_datum_case
model, loss_func, params, data = kfac_expand_exact_one_datum_case

ef_blocks = [] # list of per-parameter EFs
for param in params: