
[ADD] Support arbitrary order of parameters
f-dangel committed Oct 30, 2023
1 parent 5190c9f commit 59d8098
Showing 2 changed files with 31 additions and 20 deletions.
33 changes: 15 additions & 18 deletions curvlinops/kfac.py
@@ -93,7 +93,6 @@ def __init__(
         Warning:
             This is an early proto-type with many limitations:

-            - Parameters must be in the same order as the model's parameters.
             - Only linear layers with bias are supported.
             - Weights and biases are treated separately.
             - No weight sharing is supported.

@@ -119,8 +118,6 @@ def __init__(
         Raises:
             ValueError: If the loss function is not supported.
-            NotImplementedError: If the parameters are not in the same order as the
-                model's parameters.
             NotImplementedError: If the model contains bias-free linear layers.
             NotImplementedError: If any parameter cannot be identified with a supported
                 layer.
@@ -130,27 +127,24 @@ def __init__(
                 f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
             )

+        self.param_ids = [p.data_ptr() for p in params]
         self.hooked_modules: List[str] = []
-        idx = 0
         for name, mod in model_func.named_modules():
             if isinstance(mod, self._SUPPORTED_MODULES):
                 # TODO Support bias-free layers
                 if mod.bias is None:
                     raise NotImplementedError(
                         "Bias-free linear layers are not yet supported."
                     )
-                # TODO Support arbitrary orders and sub-sets of parameters
-                if (
-                    params[idx].data_ptr() != mod.weight.data_ptr()
-                    or params[idx + 1].data_ptr() != mod.bias.data_ptr()
-                ):
-                    raise NotImplementedError(
-                        "KFAC parameters must be in same order as model parameters "
-                        + "for now."
-                    )
-                idx += 2
                 self.hooked_modules.append(name)
-        if idx != len(params):
+
+        # check that all parameters are in hooked modules
+        hooked_param_ids = {
+            p.data_ptr()
+            for mod in self.hooked_modules
+            for p in model_func.get_submodule(mod).parameters()
+        }
+        if set(self.param_ids) != hooked_param_ids:
             raise NotImplementedError(
                 "Could not identify all parameters with supported layers."
             )
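
The set comparison above makes the constructor order-independent: each user-supplied parameter is matched to a layer by its memory address (Tensor.data_ptr()) rather than by its position in the list. A minimal standalone sketch of the idea, using a hypothetical two-layer model that does not appear in this diff:

from torch.nn import Linear, Sequential

# Hypothetical model, not from this repo; any supported layer type works.
model = Sequential(Linear(4, 3), Linear(3, 2))

# The user passes the parameters in reversed order.
params = list(model.parameters())[::-1]
param_ids = [p.data_ptr() for p in params]

# Addresses of all parameters that live in supported (hooked) layers.
hooked_param_ids = {
    p.data_ptr()
    for mod in model.modules()
    if isinstance(mod, Linear)
    for p in mod.parameters()
}

# Only set membership is checked, so any ordering passes.
assert set(param_ids) == hooked_param_ids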
@@ -188,9 +182,12 @@ def _matvec(self, x: ndarray) -> ndarray:
         x_torch = super()._preprocess(x)
         assert len(x_torch) % 2 == 0

-        for idx in range(len(x_torch) // 2):
-            idx_weight, idx_bias = 2 * idx, 2 * idx + 1
-            weight, bias = self._params[idx_weight], self._params[idx_bias]
+        for mod_name in self.hooked_modules:
+            mod = self._model_func.get_submodule(mod_name)
+            weight, bias = mod.weight, mod.bias
+
+            idx_weight = self.param_ids.index(weight.data_ptr())
+            idx_bias = self.param_ids.index(bias.data_ptr())
             x_weight, x_bias = x_torch[idx_weight], x_torch[idx_bias]

             aaT = self._input_covariances[(weight.data_ptr(), bias.data_ptr())]
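
The matvec now walks over the hooked modules and recovers the position of each weight and bias block in x_torch via an address lookup in param_ids. A rough standalone sketch of that lookup pattern, with a hypothetical layer and stand-in blocks:

from torch import rand
from torch.nn import Linear

layer = Linear(3, 2)
# Suppose the user passed [bias, weight], i.e. reversed order.
params = [layer.bias, layer.weight]
param_ids = [p.data_ptr() for p in params]
x_blocks = [rand(p.shape) for p in params]  # stand-in for x_torch

# Blocks are found by address, so the user's ordering is irrelevant.
idx_weight = param_ids.index(layer.weight.data_ptr())
idx_bias = param_ids.index(layer.bias.data_ptr())
x_weight, x_bias = x_blocks[idx_weight], x_blocks[idx_bias]
assert x_weight.shape == layer.weight.shape
assert x_bias.shape == layer.bias.shape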
18 changes: 16 additions & 2 deletions test/test_kfac.py
@@ -3,22 +3,36 @@
 from typing import Iterable, List, Tuple

 from numpy import eye
+from pytest import mark
 from scipy.linalg import block_diag
-from torch import Tensor
+from torch import Tensor, randperm
 from torch.nn import Module, MSELoss, Parameter

 from curvlinops.examples.utils import report_nonclose
 from curvlinops.ggn import GGNLinearOperator
 from curvlinops.kfac import KFACLinearOperator


+@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
 def test_kfac(
     kfac_expand_exact_case: Tuple[
         Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
-    ]
+    ],
+    shuffle: bool,
 ):
+    """Test the KFAC implementation against the exact GGN.
+
+    Args:
+        kfac_expand_exact_case: A fixture that returns a model, loss function, list of
+            parameters, and data.
+        shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
+    """
     model, loss_func, params, data = kfac_expand_exact_case

+    if shuffle:
+        permutation = randperm(len(params))
+        params = [params[i] for i in permutation]
+
     ggn_blocks = []  # list of per-parameter GGNs
     for param in params:
         ggn = GGNLinearOperator(model, loss_func, [param], data)
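
The shuffled variant checks that KFAC built from a permuted parameter list still matches the exact GGN, assembled block by block in the same permuted order. A sketch of how that reference could be built, assuming the fixture values unpacked above; the multiplication with eye to densify each operator is an assumption about the truncated remainder of the test:

from numpy import eye
from scipy.linalg import block_diag
from torch import randperm

# Assuming model, loss_func, params, data come from the fixture above.
permutation = randperm(len(params))
params = [params[i] for i in permutation]

# One dense GGN block per (shuffled) parameter ...
ggn_blocks = [
    GGNLinearOperator(model, loss_func, [p], data) @ eye(p.numel())
    for p in params
]
# ... assembled into the block-diagonal reference that KFAC must match.
ggn = block_diag(*ggn_blocks)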
