[ADD] Support CrossEntropyLoss in KFAC (#52)
* [ADD] Prototype for KFAC linear operator

* [DOC] Progress on documentation

* [DOC] Describe KFAC and its limitations

* [FIX] Name of fixture

* [FIX] Darglint

* [FIX] Darglint

See terrencepreilly/darglint#53

* [ADD] Support `CrossEntropyLoss` in KFAC
f-dangel authored Oct 30, 2023
1 parent 4b37a8f commit b3d7463
Showing 5 changed files with 92 additions and 7 deletions.
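With this change, a KFAC approximation can be built for classification models trained with `CrossEntropyLoss`. Below is a minimal usage sketch (the toy model, random data, and import path are illustrative assumptions, not part of the commit); the constructor call mirrors the new test in `test/test_kfac.py`.

```python
from numpy import eye
from torch import rand, randint
from torch.nn import CrossEntropyLoss, Linear, Sequential

from curvlinops.kfac import KFACLinearOperator

# Toy classification setup (sizes and data are placeholders for illustration).
model = Sequential(Linear(5, 4), Linear(4, 3))
loss_func = CrossEntropyLoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [(rand(8, 5), randint(0, 3, (8,)))]  # inputs and integer class labels

# CrossEntropyLoss is now accepted as loss function by the KFAC linear operator.
kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=10_000)
kfac_mat = kfac @ eye(kfac.shape[1])  # materialize the approximation, as in the tests
```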
22 changes: 20 additions & 2 deletions curvlinops/kfac.py
@@ -16,9 +16,10 @@
from math import sqrt
from typing import Dict, Iterable, List, Tuple, Union

from einops import rearrange
from numpy import ndarray
from torch import Generator, Tensor, einsum, randn
from torch.nn import Linear, Module, MSELoss, Parameter
from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter
from torch.utils.hooks import RemovableHandle

from curvlinops._base import _LinearOperator
@@ -66,7 +67,7 @@ class KFACLinearOperator(_LinearOperator):
_SUPPORTED_MODULES: Tuple of supported layers.
"""

_SUPPORTED_LOSSES = (MSELoss,)
_SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss)
_SUPPORTED_MODULES = (Linear,)

def __init__(
@@ -276,6 +277,23 @@ def draw_label(self, output: Tensor) -> Tensor:
generator=self._generator,
)
return output.clone().detach() + perturbation

elif isinstance(self._loss_func, CrossEntropyLoss):
# TODO For output.ndim > 2, the scale of the 'would-be' gradient resulting
# from these labels might be off
if output.ndim != 2:
raise NotImplementedError(
"Only 2D output is supported for CrossEntropyLoss for now."
)
probs = output.softmax(dim=1)
# each row contains a vector describing a categorical
probs_as_mat = rearrange(probs, "n c ... -> (n ...) c")
labels = probs_as_mat.multinomial(
num_samples=1, generator=self._generator
).squeeze(-1)
label_shape = output.shape[:1] + output.shape[2:]
return labels.reshape(label_shape)

else:
raise NotImplementedError

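For intuition, here is a standalone sketch (not part of the commit) of the sampling step the new `CrossEntropyLoss` branch in `draw_label` performs: softmax the logits so each row becomes a categorical distribution, then draw one label per row with `multinomial`. The toy dimensions are assumptions.

```python
from einops import rearrange
from torch import manual_seed, randn

manual_seed(0)
output = randn(4, 3)  # toy logits: N=4 data points, C=3 classes

probs = output.softmax(dim=1)  # each row is a categorical distribution over classes
# Flatten potential extra dimensions; for 2D output this is effectively a no-op.
probs_as_mat = rearrange(probs, "n c ... -> (n ...) c")
labels = probs_as_mat.multinomial(num_samples=1).squeeze(-1)

print(labels.shape)  # torch.Size([4]): one sampled class index per datum
```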
1 change: 1 addition & 0 deletions setup.cfg
@@ -39,6 +39,7 @@ install_requires =
torch>=2.0
scipy>=1.7.1,<2.0.0
tqdm>=4.61.0,<5.0.0
einops
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
python_requires = >=3.8

20 changes: 17 additions & 3 deletions test/conftest.py
@@ -1,7 +1,7 @@
"""Contains pytest fixtures that are visible by other files."""

from test.cases import ADJOINT_CASES, CASES, NON_DETERMINISTIC_CASES
from test.kfac_cases import KFAC_EXPAND_EXACT_CASES
from test.kfac_cases import KFAC_EXPAND_EXACT_CASES, KFAC_EXPAND_EXACT_ONE_DATUM_CASES
from typing import Callable, Dict, Iterable, List, Tuple

from numpy import random
@@ -70,5 +70,19 @@ def kfac_expand_exact_case(
A neural network, the mean-squared error function, a list of parameters, and
a data set.
"""
kfac_case = request.param
yield initialize_case(kfac_case)
case = request.param
yield initialize_case(case)


@fixture(params=KFAC_EXPAND_EXACT_ONE_DATUM_CASES)
def kfac_expand_exact_one_datum_case(
request,
) -> Tuple[Module, Module, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]:
"""Prepare a test case for which KFAC-expand equals the GGN and one datum is used.
Yields:
A neural network, loss function, a list of parameters, and
a data set with a single datum.
"""
case = request.param
yield initialize_case(case)
34 changes: 32 additions & 2 deletions test/kfac_cases.py
@@ -1,10 +1,10 @@
"""Contains test cases for the KFAC linear operator."""

from functools import partial
from test.utils import get_available_devices, regression_targets
from test.utils import classification_targets, get_available_devices, regression_targets

from torch import rand
from torch.nn import Linear, MSELoss, Sequential
from torch.nn import CrossEntropyLoss, Linear, MSELoss, Sequential

# Add test cases here, devices and loss function with different reductions will be
# added automatically below
@@ -42,3 +42,33 @@
"loss_func": partial(MSELoss, reduction=reduction),
}
KFAC_EXPAND_EXACT_CASES.append(case_with_device_and_loss_func)


# Add test cases here, devices will be added automatically below
KFAC_EXPAND_EXACT_ONE_DATUM_CASES_NO_DEVICE = [
###############################################################################
# CLASSIFICATION #
###############################################################################
# deep linear network with vector output (both reductions)
{
"model_func": lambda: Sequential(Linear(5, 4), Linear(4, 3)),
"loss_func": lambda: CrossEntropyLoss(reduction="mean"),
"data": lambda: [(rand(1, 5), classification_targets((1,), 3))],
"seed": 0,
},
{
"model_func": lambda: Sequential(Linear(5, 4), Linear(4, 3)),
"loss_func": lambda: CrossEntropyLoss(reduction="sum"),
"data": lambda: [(rand(1, 5), classification_targets((1,), 3))],
"seed": 0,
},
]

KFAC_EXPAND_EXACT_ONE_DATUM_CASES = []
for case in KFAC_EXPAND_EXACT_ONE_DATUM_CASES_NO_DEVICE:
for device in get_available_devices():
case_with_device = {
**case,
"device": device,
}
KFAC_EXPAND_EXACT_ONE_DATUM_CASES.append(case_with_device)
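The new cases rely on a `classification_targets` helper imported from `test/utils.py`, whose implementation is not part of this diff. A plausible minimal version (an assumption about its behavior, not the repository's exact code) looks like this:

```python
from torch import Tensor, randint


def classification_targets(size, num_classes: int) -> Tensor:
    """Return random integer class labels of the given shape (assumed sketch)."""
    return randint(0, num_classes, size)
```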
22 changes: 22 additions & 0 deletions test/test_kfac.py
@@ -46,3 +46,25 @@ def test_kfac(
rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)


def test_kfac_one_datum(
kfac_expand_exact_one_datum_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case

ggn_blocks = [] # list of per-parameter GGNs
for param in params:
ggn = GGNLinearOperator(model, loss_func, [param], data)
ggn_blocks.append(ggn @ eye(ggn.shape[1]))
ggn = block_diag(*ggn_blocks)

kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=10_000)
kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 1e-3, "mean": 1e-3}[loss_func.reduction]
rtol = {"sum": 3e-2, "mean": 3e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)
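`report_nonclose` is a test utility of the repository and is not shown in this diff; conceptually it compares the two matrices entry-wise within the given tolerances, roughly along the lines of this assumed sketch:

```python
from numpy import allclose, ndarray


def report_nonclose(mat1: ndarray, mat2: ndarray, rtol: float, atol: float) -> None:
    """Raise if the two matrices differ beyond the tolerances (assumed behavior)."""
    if not allclose(mat1, mat2, rtol=rtol, atol=atol):
        raise AssertionError("Matrices differ beyond the specified tolerances.")
```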
