[ADD] Rewrite inverse KFAC's _matvec to _matmat
f-dangel committed Feb 7, 2024
1 parent a04bd49 commit 3d7685f
Showing 2 changed files with 27 additions and 23 deletions.
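
Context for the rewrite: SciPy's LinearOperator needs only one of _matvec/_matmat. Once _matmat is implemented, the default _matvec routes a vector through it as a one-column matrix, so a single batched implementation serves both entry points. A minimal sketch of that dispatch, assuming only standard SciPy behavior (the ScaledIdentity operator is illustrative, not from curvlinops):

    from numpy import random
    from scipy.sparse.linalg import LinearOperator

    class ScaledIdentity(LinearOperator):
        """Toy operator that implements only ``_matmat``."""

        def __init__(self, dim: int, scale: float = 2.0):
            super().__init__(shape=(dim, dim), dtype=float)
            self._scale = scale

        def _matmat(self, M):
            # M has shape (dim, num_vectors): one pass handles all columns.
            return self._scale * M

    op = ScaledIdentity(4)
    v, X = random.rand(4), random.rand(4, 3)
    print(op @ v)          # vector path: SciPy reshapes v to (4, 1) and calls _matmat
    print((op @ X).shape)  # matrix path: (4, 3), all columns in one call

For inverse KFAC, this means the Kronecker-factor inverses below are applied to all columns of M in one pass instead of once per vector.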
41 changes: 22 additions & 19 deletions curvlinops/inverse.py
@@ -2,7 +2,7 @@

 from typing import Dict, Optional, Tuple

-from einops import rearrange
+from einops import einsum, rearrange
 from numpy import allclose, column_stack, ndarray
 from scipy.sparse.linalg import LinearOperator, cg
 from torch import Tensor, cat, cholesky_inverse, eye
@@ -274,19 +274,19 @@ def _compute_or_get_cached_inverse(
             self._inverse_gradient_covariances[name] = ggT_inv
         return aaT_inv, ggT_inv

-    def _matvec(self, x: ndarray) -> ndarray:
-        """Multiply x by the inverse of A.
+    def _matmat(self, M: ndarray) -> ndarray:
+        """Multiply a matrix ``M`` by the inverse of KFAC.

         Args:
-            x: Vector for multiplication.
+            M: Matrix for multiplication.

         Returns:
-            Result of inverse matrix-vector multiplication, ``A⁻¹ @ x``.
+            Result of inverse matrix-matrix multiplication, ``KFAC⁻¹ @ M``.
         """
         if not self._A._input_covariances and not self._A._gradient_covariances:
             self._A._compute_kfac()

-        x_torch = self._A._preprocess(x)
+        M_torch = self._A._preprocess(M)

         for name in self._A.param_ids_to_hooked_modules.values():
             mod = self._A._model_func.get_submodule(name)
@@ -295,18 +295,17 @@ def _matvec(self, x: ndarray) -> ndarray:
             aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(name)

             # bias and weights are treated jointly
+            weight, bias = mod.weight, mod.bias
             if not self._A._separate_weight_and_bias and self._A.in_params(
-                mod.weight, mod.bias
+                weight, bias
             ):
-                w_pos, b_pos = self._A.param_pos(mod.weight), self._A.param_pos(
-                    mod.bias
-                )
-                x_w = rearrange(x_torch[w_pos], "c_out ... -> c_out (...)")
-                x_joint = cat([x_w, x_torch[b_pos].unsqueeze(-1)], dim=1)
-                x_joint = ggT_inv @ x_joint @ aaT_inv
+                w_pos, b_pos = self._A.param_pos(weight), self._A.param_pos(bias)
+                M_w = rearrange(M_torch[w_pos], "m c_out ... -> m c_out (...)")
+                M_joint = cat([M_w, M_torch[b_pos].unsqueeze(2)], dim=2)
+                M_joint = einsum(ggT_inv, M_joint, aaT_inv, "i j,m j k, k l -> m i l")

-                w_cols = x_w.shape[1]
-                x_torch[w_pos], x_torch[b_pos] = x_joint.split([w_cols, 1], dim=1)
+                w_cols = M_w.shape[2]
+                M_torch[w_pos], M_torch[b_pos] = M_joint.split([w_cols, 1], dim=2)

             # for weights we need to multiply from the right with aaT
             # for weights and biases we need to multiply from the left with ggT
@@ -317,9 +316,13 @@ def _matvec(self, x: ndarray) -> ndarray:
                         pos = self._A.param_pos(p)

                         if p_name == "weight":
-                            x_w = rearrange(x_torch[pos], "c_out ... -> c_out (...)")
-                            x_torch[pos] = x_w @ aaT_inv
+                            M_w = rearrange(
+                                M_torch[pos], "m c_out ... -> m c_out (...)"
+                            )
+                            M_torch[pos] = einsum(M_w, aaT_inv, "m i j, j k -> m i k")

-                        x_torch[pos] = ggT_inv @ x_torch[pos]
+                        M_torch[pos] = einsum(
+                            ggT_inv, M_torch[pos], "i j, m j ... -> m i ..."
+                        )

-        return self._A._postprocess(x_torch)
+        return self._A._postprocess(M_torch)
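
The einsum patterns above treat the leading axis m (the column index of the matrix M) as a batch dimension, so one contraction replaces the per-vector sandwich product ggT_inv @ x @ aaT_inv of the old _matvec. A short self-contained check of that equivalence, with made-up shapes:

    from einops import einsum
    from torch import allclose, rand, stack

    m, d_out, d_in = 5, 2, 4  # hypothetical: 5 columns, a 4 -> 2 linear layer
    ggT_inv, aaT_inv = rand(d_out, d_out), rand(d_in, d_in)
    M_joint = rand(m, d_out, d_in)

    # one contraction over all m slices at once ...
    batched = einsum(ggT_inv, M_joint, aaT_inv, "i j, m j k, k l -> m i l")
    # ... matches the per-slice sandwich product of the old _matvec
    looped = stack([ggT_inv @ M_joint[i] @ aaT_inv for i in range(m)])
    assert allclose(batched, looped)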
9 changes: 5 additions & 4 deletions test/test_inverse.py
@@ -137,10 +137,10 @@ def test_NeumannInverseLinearOperator_toy():
 @mark.parametrize(
     "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"]
 )
-def test_KFAC_inverse_damped_matvec(
+def test_KFAC_inverse_damped_matmat(
     case, cache: bool, exclude: str, separate_weight_and_bias: bool, delta: float = 1e-2
 ):
-    """Test matrix-vector multiplication by an inverse damped KFAC approximation."""
+    """Test matrix-matrix multiplication by an inverse damped KFAC approximation."""
     model_func, loss_func, params, data = case

     if exclude is not None:
@@ -172,8 +172,9 @@ def test_KFAC_inverse_damped_matvec(
     ggT.sub_(torch.eye(ggT.shape[0], device=ggT.device), alpha=delta)
     inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(delta, delta), cache=cache)

-    x = random.rand(KFAC.shape[1])
-    report_nonclose(inv_KFAC @ x, inv_KFAC_naive @ x, rtol=5e-2)
+    num_vectors = 2
+    X = random.rand(KFAC.shape[1], num_vectors)
+    report_nonclose(inv_KFAC @ X, inv_KFAC_naive @ X, rtol=5e-2)

     assert inv_KFAC._cache == cache
     if cache:
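
A hypothetical usage sketch of what the updated test exercises, reusing inv_KFAC from above: one call now applies the damped inverse to several vectors stacked column-wise, while plain vectors keep working through SciPy's matvec fallback.

    from numpy import random

    X = random.rand(inv_KFAC.shape[1], 8)  # eight probe vectors, stacked as columns
    Y = inv_KFAC @ X                       # single _matmat call, result shape (dim, 8)
    y = inv_KFAC @ X[:, 0]                 # a 1-D vector still works via SciPy's fallback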
