[ADD] Allow subsets of parameters (weight-only, bias-only) in KFAC (#55)
* [ADD] Support treating only biases of a layer

* [ADD] Support layers without bias

* [REF] Use a mapping from parameter ids to module names

* [DOC] Remove outdated doc
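As a minimal sketch of what the change enables (toy model and data; the constructor arguments mirror the test fixture further down this page):

import numpy as np
import torch
from torch.nn import Linear, MSELoss, Sequential

from curvlinops.kfac import KFACLinearOperator

# illustrative stand-ins for a real model, loss function, and data set
model = Sequential(Linear(4, 3), Linear(3, 2))
loss_func = MSELoss()
data = [(torch.randn(8, 4), torch.randn(8, 2))]

# bias-only KFAC: simply pass the parameter subset to the operator
params = [p for name, p in model.named_parameters() if "bias" in name]
kfac = KFACLinearOperator(model, loss_func, params, data)

v = np.random.rand(sum(p.numel() for p in params))
Fv = kfac @ v  # Kronecker factors are computed lazily on the first matvec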
f-dangel authored Nov 7, 2023
1 parent 238a7e4 commit 2eb654f
Showing 3 changed files with 90 additions and 54 deletions.
118 changes: 64 additions & 54 deletions curvlinops/kfac.py
@@ -14,7 +14,7 @@
from __future__ import annotations

from math import sqrt
from typing import Dict, Iterable, List, Tuple, Union
from typing import Dict, Iterable, List, Set, Tuple, Union

from einops import rearrange
from numpy import ndarray
@@ -95,10 +95,9 @@ def __init__(
Warning:
This is an early prototype with many limitations:
- Only linear layers with bias are supported.
- Only linear layers are supported.
- Weights and biases are treated separately.
- No weight sharing is supported.
- Only the Monte-Carlo sampled version is supported.
- Only the ``'expand'`` approximation is supported.
Args:
@@ -130,7 +129,6 @@ def __init__(
Raises:
ValueError: If the loss function is not supported.
NotImplementedError: If the model contains bias-free linear layers.
NotImplementedError: If any parameter cannot be identified with a supported
layer.
"""
@@ -140,33 +138,28 @@ def __init__(
)

self.param_ids = [p.data_ptr() for p in params]
self.hooked_modules: List[str] = []
# mapping from tuples of parameter data pointers in a module to its name
self.param_ids_to_hooked_modules: Dict[Tuple[int, ...], str] = {}

hooked_param_ids: Set[int] = set()
for name, mod in model_func.named_modules():
if isinstance(mod, self._SUPPORTED_MODULES):
# TODO Support bias-free layers
if mod.bias is None:
raise NotImplementedError(
"Bias-free linear layers are not yet supported."
)
self.hooked_modules.append(name)
p_ids = tuple(p.data_ptr() for p in mod.parameters())
if isinstance(mod, self._SUPPORTED_MODULES) and any(
p_id in self.param_ids for p_id in p_ids
):
self.param_ids_to_hooked_modules[p_ids] = name
hooked_param_ids.update(set(p_ids))

# check that all parameters are in hooked modules
hooked_param_ids = {
p.data_ptr()
for mod in self.hooked_modules
for p in model_func.get_submodule(mod).parameters()
}
if set(self.param_ids) != hooked_param_ids:
raise NotImplementedError(
"Could not identify all parameters with supported layers."
)
if not set(self.param_ids).issubset(hooked_param_ids):
raise NotImplementedError("Found parameters outside supported layers.")

self._seed = seed
self._generator: Union[None, Generator] = None
self._fisher_type = fisher_type
self._mc_samples = mc_samples
self._input_covariances: Dict[Tuple[int, ...], Tensor] = {}
self._gradient_covariances: Dict[Tuple[int, ...], Tensor] = {}
self._input_covariances: Dict[str, Tensor] = {}
self._gradient_covariances: Dict[str, Tensor] = {}

super().__init__(
model_func,
Expand All @@ -193,21 +186,20 @@ def _matvec(self, x: ndarray) -> ndarray:
self._compute_kfac()

x_torch = super()._preprocess(x)
assert len(x_torch) % 2 == 0

for mod_name in self.hooked_modules:
mod = self._model_func.get_submodule(mod_name)
weight, bias = mod.weight, mod.bias

idx_weight = self.param_ids.index(weight.data_ptr())
idx_bias = self.param_ids.index(bias.data_ptr())
x_weight, x_bias = x_torch[idx_weight], x_torch[idx_bias]
for name in self.param_ids_to_hooked_modules.values():
mod = self._model_func.get_submodule(name)

aaT = self._input_covariances[(weight.data_ptr(), bias.data_ptr())]
ggT = self._gradient_covariances[(weight.data_ptr(), bias.data_ptr())]
if mod.weight.data_ptr() in self.param_ids:
idx = self.param_ids.index(mod.weight.data_ptr())
aaT = self._input_covariances[name]
ggT = self._gradient_covariances[name]
x_torch[idx] = ggT @ x_torch[idx] @ aaT

x_torch[idx_weight] = ggT @ x_weight @ aaT
x_torch[idx_bias] = ggT @ x_bias
if mod.bias is not None and mod.bias.data_ptr() in self.param_ids:
idx = self.param_ids.index(mod.bias.data_ptr())
ggT = self._gradient_covariances[name]
x_torch[idx] = ggT @ x_torch[idx]

return super()._postprocess(x_torch)
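The two products above apply the Kronecker-factored block without materializing it: under PyTorch's row-major flattening, (ggT ⊗ aaT) vec(X) = vec(ggT @ X @ aaT) because aaT is symmetric. A quick numerical sanity check with stand-in matrices (not the operator's internals):

import torch

torch.manual_seed(0)
G = torch.randn(3, 3, dtype=torch.float64)
G = G @ G.T  # stand-in for a gradient covariance ggT
A = torch.randn(5, 5, dtype=torch.float64)
A = A @ A.T  # stand-in for an input covariance aaT
X = torch.randn(3, 5, dtype=torch.float64)  # weight-shaped slice of x

dense = torch.kron(G, A) @ X.flatten()  # explicit Kronecker matvec
cheap = (G @ X @ A).flatten()           # the reshaped form used in _matvec
assert torch.allclose(dense, cheap)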

@@ -231,18 +223,24 @@ def _compute_kfac(self):
"""
# install forward and backward hooks
hook_handles: List[RemovableHandle] = []
hook_handles.extend(
self._model_func.get_submodule(mod).register_forward_pre_hook(
self._hook_accumulate_input_covariance
)
for mod in self.hooked_modules
)
hook_handles.extend(
self._model_func.get_submodule(mod).register_full_backward_hook(
self._hook_accumulate_gradient_covariance

for name in self.param_ids_to_hooked_modules.values():
module = self._model_func.get_submodule(name)

# input covariance only required for weights
if module.weight.data_ptr() in self.param_ids:
hook_handles.append(
module.register_forward_pre_hook(
self._hook_accumulate_input_covariance
)
)

# gradient covariance required for weights and biases
hook_handles.append(
module.register_full_backward_hook(
self._hook_accumulate_gradient_covariance
)
)
for mod in self.hooked_modules
)

# loop over data set, computing the Kronecker factors
if self._generator is None:
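For readers unfamiliar with the two PyTorch hook types used above, a self-contained illustration of their call signatures (toy layer; the real hooks accumulate covariances instead of printing):

import torch
from torch.nn import Linear

layer = Linear(4, 3)

def forward_pre(module, inputs):
    # fires before forward and sees the layer inputs (source of aaT)
    print("layer input:", inputs[0].shape)

def full_backward(module, grad_input, grad_output):
    # fires during backward and sees the output gradients (source of ggT)
    print("layer output gradient:", grad_output[0].shape)

handles = [
    layer.register_forward_pre_hook(forward_pre),
    layer.register_full_backward_hook(full_backward),
]
layer(torch.randn(8, 4)).sum().backward()

for handle in handles:
    handle.remove()  # both registrations return removable handles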
@@ -373,11 +371,11 @@ def _hook_accumulate_gradient_covariance(
+ f"Supported layers: {self._SUPPORTED_MODULES}."
)

idx = tuple(p.data_ptr() for p in module.parameters())
if idx not in self._gradient_covariances:
self._gradient_covariances[idx] = covariance
name = self.get_module_name(module)
if name not in self._gradient_covariances:
self._gradient_covariances[name] = covariance
else:
self._gradient_covariances[idx].add_(covariance)
self._gradient_covariances[name].add_(covariance)

def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor]):
"""Pre-forward hook that accumulates the input covariance of a layer.
@@ -409,8 +407,20 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
# TODO Support convolutions
raise NotImplementedError(f"Layer of type {type(module)} is unsupported.")

idx = tuple(p.data_ptr() for p in module.parameters())
if idx not in self._input_covariances:
self._input_covariances[idx] = covariance
name = self.get_module_name(module)
if name not in self._input_covariances:
self._input_covariances[name] = covariance
else:
self._input_covariances[idx].add_(covariance)
self._input_covariances[name].add_(covariance)

def get_module_name(self, module: Module) -> str:
"""Get the name of a module.
Args:
module: The module.
Returns:
The name of the module.
"""
p_ids = tuple(p.data_ptr() for p in module.parameters())
return self.param_ids_to_hooked_modules[p_ids]
11 changes: 11 additions & 0 deletions makefile
@@ -5,6 +5,8 @@ help:
@echo " Install curvlinops and dependencies"
@echo "uninstall"
@echo " Unstall curvlinops"
@echo "lint"
@echo " Run all linting actions"
@echo "docs"
@echo " Build the documentation"
@echo "install-dev"
@@ -114,3 +116,12 @@ pydocstyle-check:

conda-env:
@conda env create --file .conda_env.yml

.PHONY: lint

lint:
make black-check
make isort-check
make flake8
make darglint-check
make pydocstyle-check
15 changes: 15 additions & 0 deletions test/test_kfac.py
@@ -14,22 +14,33 @@
from curvlinops.kfac import KFACLinearOperator


@mark.parametrize(
"exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"]
)
@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
def test_kfac(
kfac_expand_exact_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
],
shuffle: bool,
exclude: str,
):
"""Test the KFAC implementation against the exact GGN.
Args:
kfac_expand_exact_case: A fixture that returns a model, loss function, list of
parameters, and data.
shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
exclude: Which parameters to exclude. Can be ``'weight'``, ``'bias'``,
or ``None``.
"""
assert exclude in [None, "weight", "bias"]
model, loss_func, params, data = kfac_expand_exact_case

if exclude is not None:
names = {p.data_ptr(): name for name, p in model.named_parameters()}
params = [p for p in params if exclude not in names[p.data_ptr()]]

if shuffle:
permutation = randperm(len(params))
params = [params[i] for i in permutation]
@@ -48,6 +59,10 @@ def test_kfac(

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)

# Check that input covariances were not computed
if exclude == "weight":
assert len(kfac._input_covariances) == 0


def test_kfac_one_datum(
kfac_expand_exact_one_datum_case: Tuple[