
[ADD] Support arbitrarily-ordered params in KFAC #51

Merged on Oct 30, 2023 (9 commits)
curvlinops/kfac.py: 33 changes (15 additions, 18 deletions)
@@ -93,7 +93,6 @@ def __init__(
         Warning:
             This is an early proto-type with many limitations:
 
-            - Parameters must be in the same order as the model's parameters.
             - Only linear layers with bias are supported.
             - Weights and biases are treated separately.
             - No weight sharing is supported.
@@ -119,8 +118,6 @@ def __init__(
 
         Raises:
             ValueError: If the loss function is not supported.
-            NotImplementedError: If the parameters are not in the same order as the
-                model's parameters.
             NotImplementedError: If the model contains bias-free linear layers.
             NotImplementedError: If any parameter cannot be identified with a supported
                 layer.
@@ -130,27 +127,24 @@
                 f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
             )
 
+        self.param_ids = [p.data_ptr() for p in params]
         self.hooked_modules: List[str] = []
-        idx = 0
         for name, mod in model_func.named_modules():
             if isinstance(mod, self._SUPPORTED_MODULES):
                 # TODO Support bias-free layers
                 if mod.bias is None:
                     raise NotImplementedError(
                         "Bias-free linear layers are not yet supported."
                     )
-                # TODO Support arbitrary orders and sub-sets of parameters
-                if (
-                    params[idx].data_ptr() != mod.weight.data_ptr()
-                    or params[idx + 1].data_ptr() != mod.bias.data_ptr()
-                ):
-                    raise NotImplementedError(
-                        "KFAC parameters must be in same order as model parameters "
-                        + "for now."
-                    )
-                idx += 2
                 self.hooked_modules.append(name)
-        if idx != len(params):
+
+        # check that all parameters are in hooked modules
+        hooked_param_ids = {
+            p.data_ptr()
+            for mod in self.hooked_modules
+            for p in model_func.get_submodule(mod).parameters()
+        }
+        if set(self.param_ids) != hooked_param_ids:
             raise NotImplementedError(
                 "Could not identify all parameters with supported layers."
             )
@@ -188,9 +182,12 @@ def _matvec(self, x: ndarray) -> ndarray:
         x_torch = super()._preprocess(x)
         assert len(x_torch) % 2 == 0
 
-        for idx in range(len(x_torch) // 2):
-            idx_weight, idx_bias = 2 * idx, 2 * idx + 1
-            weight, bias = self._params[idx_weight], self._params[idx_bias]
+        for mod_name in self.hooked_modules:
+            mod = self._model_func.get_submodule(mod_name)
+            weight, bias = mod.weight, mod.bias
+
+            idx_weight = self.param_ids.index(weight.data_ptr())
+            idx_bias = self.param_ids.index(bias.data_ptr())
             x_weight, x_bias = x_torch[idx_weight], x_torch[idx_bias]
 
             aaT = self._input_covariances[(weight.data_ptr(), bias.data_ptr())]
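The core idea of the change above: a parameter tensor's data_ptr() acts as a stable identifier, so user-supplied parameters can be matched back to their modules regardless of ordering. Below is a minimal standalone sketch of this matching, not the library's code; the Sequential model and the shuffled parameter list are made up for illustration:

    from torch.nn import Linear, Sequential

    model = Sequential(Linear(4, 3), Linear(3, 2))
    # parameters supplied in an arbitrary order, unlike model.parameters()
    params = [model[1].bias, model[0].weight, model[1].weight, model[0].bias]

    param_ids = [p.data_ptr() for p in params]
    for mod in model.modules():
        if isinstance(mod, Linear):
            # recover the positions of this layer's weight and bias in the
            # user-supplied list, analogous to what _matvec does above
            idx_weight = param_ids.index(mod.weight.data_ptr())
            idx_bias = param_ids.index(mod.bias.data_ptr())
            print(mod, idx_weight, idx_bias)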
test/test_kfac.py: 18 changes (16 additions, 2 deletions)
@@ -3,22 +3,36 @@
 from typing import Iterable, List, Tuple
 
 from numpy import eye
+from pytest import mark
 from scipy.linalg import block_diag
-from torch import Tensor
+from torch import Tensor, randperm
 from torch.nn import Module, MSELoss, Parameter
 
 from curvlinops.examples.utils import report_nonclose
 from curvlinops.ggn import GGNLinearOperator
 from curvlinops.kfac import KFACLinearOperator
 
 
+@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
 def test_kfac(
     kfac_expand_exact_case: Tuple[
         Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
-    ]
+    ],
+    shuffle: bool,
 ):
+    """Test the KFAC implementation against the exact GGN.
+
+    Args:
+        kfac_expand_exact_case: A fixture that returns a model, loss function, list of
+            parameters, and data.
+        shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
+    """
     model, loss_func, params, data = kfac_expand_exact_case
 
+    if shuffle:
+        permutation = randperm(len(params))
+        params = [params[i] for i in permutation]
+
     ggn_blocks = []  # list of per-parameter GGNs
     for param in params:
         ggn = GGNLinearOperator(model, loss_func, [param], data)
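For context, the new "shuffled" test case boils down to the following usage pattern. This is a hypothetical toy setup (the real test drives everything through the kfac_expand_exact_case fixture, and the constructor signature is assumed to mirror the GGNLinearOperator call visible above):

    from torch import manual_seed, rand, randperm
    from torch.nn import Linear, MSELoss, Sequential

    from curvlinops.kfac import KFACLinearOperator

    manual_seed(0)
    model = Sequential(Linear(5, 4), Linear(4, 3))
    loss_func = MSELoss()
    data = [(rand(8, 5), rand(8, 3))]

    # shuffle the parameter list; after this PR, KFAC no longer requires
    # the order produced by model.parameters()
    params = [p for p in model.parameters() if p.requires_grad]
    permutation = randperm(len(params))
    params = [params[i] for i in permutation]

    kfac = KFACLinearOperator(model, loss_func, params, data)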