Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] Support CrossEntropyLoss in KFAC #52

Merged
merged 8 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from math import sqrt
from typing import Dict, Iterable, List, Tuple, Union

from einops import rearrange
from numpy import ndarray
from torch import Generator, Tensor, einsum, randn
from torch.nn import Linear, Module, MSELoss, Parameter
from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter
from torch.utils.hooks import RemovableHandle

from curvlinops._base import _LinearOperator
Expand Down Expand Up @@ -66,7 +67,7 @@ class KFACLinearOperator(_LinearOperator):
_SUPPORTED_MODULES: Tuple of supported layers.
"""

_SUPPORTED_LOSSES = (MSELoss,)
_SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss)
_SUPPORTED_MODULES = (Linear,)

def __init__(
Expand Down Expand Up @@ -279,6 +280,23 @@ def draw_label(self, output: Tensor) -> Tensor:
generator=self._generator,
)
return output.clone().detach() + perturbation

elif isinstance(self._loss_func, CrossEntropyLoss):
# TODO For output.ndim > 2, the scale of the 'would-be' gradient resulting
# from these labels might be off
if output.ndim != 2:
raise NotImplementedError(
"Only 2D output is supported for CrossEntropyLoss for now."
)
probs = output.softmax(dim=1)
# each row contains a vector describing a categorical
probs_as_mat = rearrange(probs, "n c ... -> (n ...) c")
f-dangel marked this conversation as resolved.
Show resolved Hide resolved
labels = probs_as_mat.multinomial(
num_samples=1, generator=self._generator
).squeeze(-1)
label_shape = output.shape[:1] + output.shape[2:]
return labels.reshape(label_shape)

else:
raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ install_requires =
torch>=2.0
scipy>=1.7.1,<2.0.0
tqdm>=4.61.0,<5.0.0
einops
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
python_requires = >=3.8

Expand Down
20 changes: 17 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Contains pytest fixtures that are visible by other files."""

from test.cases import ADJOINT_CASES, CASES, NON_DETERMINISTIC_CASES
from test.kfac_cases import KFAC_EXPAND_EXACT_CASES
from test.kfac_cases import KFAC_EXPAND_EXACT_CASES, KFAC_EXPAND_EXACT_ONE_DATUM_CASES
from typing import Callable, Dict, Iterable, List, Tuple

from numpy import random
Expand Down Expand Up @@ -70,5 +70,19 @@ def kfac_expand_exact_case(
A neural network, the mean-squared error function, a list of parameters, and
a data set.
"""
kfac_case = request.param
yield initialize_case(kfac_case)
case = request.param
yield initialize_case(case)


@fixture(params=KFAC_EXPAND_EXACT_ONE_DATUM_CASES)
def kfac_expand_exact_one_datum_case(
request,
) -> Tuple[Module, Module, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]:
"""Prepare a test case for which KFAC-expand equals the GGN and one datum is used.

Yields:
A neural network, loss function, a list of parameters, and
a data set with a single datum.
"""
case = request.param
yield initialize_case(case)
34 changes: 32 additions & 2 deletions test/kfac_cases.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Contains test cases for the KFAC linear operator."""

from functools import partial
from test.utils import get_available_devices, regression_targets
from test.utils import classification_targets, get_available_devices, regression_targets

from torch import rand
from torch.nn import Linear, MSELoss, Sequential
from torch.nn import CrossEntropyLoss, Linear, MSELoss, Sequential

# Add test cases here, devices and loss function with different reductions will be
# added automatically below
Expand Down Expand Up @@ -42,3 +42,33 @@
"loss_func": partial(MSELoss, reduction=reduction),
}
KFAC_EXPAND_EXACT_CASES.append(case_with_device_and_loss_func)


# Add test cases here, devices will be added automatically below
KFAC_EXPAND_EXACT_ONE_DATUM_CASES_NO_DEVICE = [
###############################################################################
# CLASSIFICATION #
###############################################################################
# deep linear network with vector output (both reductions)
{
"model_func": lambda: Sequential(Linear(5, 4), Linear(4, 3)),
"loss_func": lambda: CrossEntropyLoss(reduction="mean"),
"data": lambda: [(rand(1, 5), classification_targets((1,), 3))],
"seed": 0,
},
{
"model_func": lambda: Sequential(Linear(5, 4), Linear(4, 3)),
"loss_func": lambda: CrossEntropyLoss(reduction="sum"),
"data": lambda: [(rand(1, 5), classification_targets((1,), 3))],
"seed": 0,
},
]

KFAC_EXPAND_EXACT_ONE_DATUM_CASES = []
for case in KFAC_EXPAND_EXACT_ONE_DATUM_CASES_NO_DEVICE:
for device in get_available_devices():
case_with_device = {
**case,
"device": device,
}
KFAC_EXPAND_EXACT_ONE_DATUM_CASES.append(case_with_device)
22 changes: 22 additions & 0 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,25 @@ def test_kfac(
rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)


def test_kfac_one_datum(
kfac_expand_exact_one_datum_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case

ggn_blocks = [] # list of per-parameter GGNs
for param in params:
ggn = GGNLinearOperator(model, loss_func, [param], data)
ggn_blocks.append(ggn @ eye(ggn.shape[1]))
ggn = block_diag(*ggn_blocks)

kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=10_000)
kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 1e-3, "mean": 1e-3}[loss_func.reduction]
rtol = {"sum": 3e-2, "mean": 3e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)
Loading