diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 10443f8..aeb7550 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -16,9 +16,10 @@ from math import sqrt from typing import Dict, Iterable, List, Tuple, Union +from einops import rearrange from numpy import ndarray from torch import Generator, Tensor, einsum, randn -from torch.nn import Linear, Module, MSELoss, Parameter +from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter from torch.utils.hooks import RemovableHandle from curvlinops._base import _LinearOperator @@ -66,7 +67,7 @@ class KFACLinearOperator(_LinearOperator): _SUPPORTED_MODULES: Tuple of supported layers. """ - _SUPPORTED_LOSSES = (MSELoss,) + _SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss) _SUPPORTED_MODULES = (Linear,) def __init__( @@ -276,6 +277,23 @@ def draw_label(self, output: Tensor) -> Tensor: generator=self._generator, ) return output.clone().detach() + perturbation + + elif isinstance(self._loss_func, CrossEntropyLoss): + # TODO For output.ndim > 2, the scale of the 'would-be' gradient resulting + # from these labels might be off + if output.ndim != 2: + raise NotImplementedError( + "Only 2D output is supported for CrossEntropyLoss for now." + ) + probs = output.softmax(dim=1) + # each row contains a vector describing a categorical + probs_as_mat = rearrange(probs, "n c ... -> (n ...) c") + labels = probs_as_mat.multinomial( + num_samples=1, generator=self._generator + ).squeeze(-1) + label_shape = output.shape[:1] + output.shape[2:] + return labels.reshape(label_shape) + else: raise NotImplementedError diff --git a/setup.cfg b/setup.cfg index 21001df..880d07c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ install_requires = torch>=2.0 scipy>=1.7.1,<2.0.0 tqdm>=4.61.0,<5.0.0 + einops # Require a specific Python version, e.g. Python 2.7 or >= 3.4 python_requires = >=3.8 diff --git a/test/conftest.py b/test/conftest.py index b3f4726..10f0617 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,7 +1,7 @@ """Contains pytest fixtures that are visible by other files.""" from test.cases import ADJOINT_CASES, CASES, NON_DETERMINISTIC_CASES -from test.kfac_cases import KFAC_EXPAND_EXACT_CASES +from test.kfac_cases import KFAC_EXPAND_EXACT_CASES, KFAC_EXPAND_EXACT_ONE_DATUM_CASES from typing import Callable, Dict, Iterable, List, Tuple from numpy import random @@ -70,5 +70,19 @@ def kfac_expand_exact_case( A neural network, the mean-squared error function, a list of parameters, and a data set. """ - kfac_case = request.param - yield initialize_case(kfac_case) + case = request.param + yield initialize_case(case) + + +@fixture(params=KFAC_EXPAND_EXACT_ONE_DATUM_CASES) +def kfac_expand_exact_one_datum_case( + request, +) -> Tuple[Module, Module, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]: + """Prepare a test case for which KFAC-expand equals the GGN and one datum is used. + + Yields: + A neural network, loss function, a list of parameters, and + a data set with a single datum. + """ + case = request.param + yield initialize_case(case) diff --git a/test/kfac_cases.py b/test/kfac_cases.py index 4b96f3c..33a7669 100644 --- a/test/kfac_cases.py +++ b/test/kfac_cases.py @@ -1,10 +1,10 @@ """Contains test cases for the KFAC linear operator.""" from functools import partial -from test.utils import get_available_devices, regression_targets +from test.utils import classification_targets, get_available_devices, regression_targets from torch import rand -from torch.nn import Linear, MSELoss, Sequential +from torch.nn import CrossEntropyLoss, Linear, MSELoss, Sequential # Add test cases here, devices and loss function with different reductions will be # added automatically below @@ -42,3 +42,33 @@ "loss_func": partial(MSELoss, reduction=reduction), } KFAC_EXPAND_EXACT_CASES.append(case_with_device_and_loss_func) + + +# Add test cases here, devices will be added automatically below +KFAC_EXPAND_EXACT_ONE_DATUM_CASES_NO_DEVICE = [ + ############################################################################### + # CLASSIFICATION # + ############################################################################### + # deep linear network with vector output (both reductions) + { + "model_func": lambda: Sequential(Linear(5, 4), Linear(4, 3)), + "loss_func": lambda: CrossEntropyLoss(reduction="mean"), + "data": lambda: [(rand(1, 5), classification_targets((1,), 3))], + "seed": 0, + }, + { + "model_func": lambda: Sequential(Linear(5, 4), Linear(4, 3)), + "loss_func": lambda: CrossEntropyLoss(reduction="sum"), + "data": lambda: [(rand(1, 5), classification_targets((1,), 3))], + "seed": 0, + }, +] + +KFAC_EXPAND_EXACT_ONE_DATUM_CASES = [] +for case in KFAC_EXPAND_EXACT_ONE_DATUM_CASES_NO_DEVICE: + for device in get_available_devices(): + case_with_device = { + **case, + "device": device, + } + KFAC_EXPAND_EXACT_ONE_DATUM_CASES.append(case_with_device) diff --git a/test/test_kfac.py b/test/test_kfac.py index 365ce3e..b384646 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -46,3 +46,25 @@ def test_kfac( rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction] report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol) + + +def test_kfac_one_datum( + kfac_expand_exact_one_datum_case: Tuple[ + Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] + ] +): + model, loss_func, params, data = kfac_expand_exact_one_datum_case + + ggn_blocks = [] # list of per-parameter GGNs + for param in params: + ggn = GGNLinearOperator(model, loss_func, [param], data) + ggn_blocks.append(ggn @ eye(ggn.shape[1])) + ggn = block_diag(*ggn_blocks) + + kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=10_000) + kfac_mat = kfac @ eye(kfac.shape[1]) + + atol = {"sum": 1e-3, "mean": 1e-3}[loss_func.reduction] + rtol = {"sum": 3e-2, "mean": 3e-2}[loss_func.reduction] + + report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)