Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for matrix-matrix and matrix-vector products with KFACLinearOperator and KFACInverseLinearOperator without converting to numpy #91

Merged
merged 16 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Implements linear operator inverses."""

from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

from einops import einsum, rearrange
from numpy import allclose, column_stack, ndarray
Expand Down Expand Up @@ -274,20 +274,38 @@ def _compute_or_get_cached_inverse(
self._inverse_gradient_covariances[name] = ggT_inv
return aaT_inv, ggT_inv

def _matmat(self, M: ndarray) -> ndarray:
"""Multiply a matrix ``M`` x by the inverse of KFAC.
def torch_matmat(
self, M_torch: Union[Tensor, List[Tensor]], return_tensor: bool = True
) -> Union[Tensor, List[Tensor]]:
"""Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch.

This allows for matrix-matrix products with the inverse KFAC approximation in
PyTorch without converting tensors to numpy arrays, which avoids unnecessary
device transfers when working with GPUs.
runame marked this conversation as resolved.
Show resolved Hide resolved

Args:
M: Matrix for multiplication.
M_torch: Matrix for multiplication. If tensor, has shape ``[D, K]`` with
some ``K``.
runame marked this conversation as resolved.
Show resolved Hide resolved
return_tensor: Whether to return the result as a tensor or list of tensors.
runame marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Result of inverse matrix-matrixmultiplication, ``KFAC⁻¹ @ M``.
Matrix-multiplication result ``KFAC⁻¹ @ M``. If tensor, has shape ``[D, K]``.
runame marked this conversation as resolved.
Show resolved Hide resolved

Raises:
ValueError: If the input tensor has the wrong shape.
ValueError: If the input tensor's shape is incompatible with the KFAC
approximation's shape.
"""
if not isinstance(M_torch, list):
if M_torch.ndim != 2:
raise ValueError(f"expected 2-d tensor, not {M_torch.ndim}-d")
if M_torch.shape[0] != self.shape[1]:
raise ValueError(f"dimension mismatch: {self.shape}, {M_torch.shape}")
M_torch = self._A._torch_preprocess(M_torch)

runame marked this conversation as resolved.
Show resolved Hide resolved
if not self._A._input_covariances and not self._A._gradient_covariances:
self._A._compute_kfac()

M_torch = self._A._preprocess(M)

for mod_name, param_pos in self._A._mapping.items():
# retrieve the inverses of the Kronecker factors from cache or invert them
aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name)
Expand Down Expand Up @@ -318,4 +336,44 @@ def _matmat(self, M: ndarray) -> ndarray:
ggT_inv, M_torch[pos], "i j, m j ... -> m i ..."
)

if return_tensor:
M_torch = cat([rearrange(M, "k ... -> (...) k") for M in M_torch], dim=0)

return M_torch

def torch_matvec(
    self, v_torch: Tensor, return_tensor: bool = True
) -> Union[Tensor, List[Tensor]]:
    """Apply the inverse of KFAC to a vector in PyTorch.

    This allows for matrix-vector products with the inverse KFAC approximation in
    PyTorch without converting tensors to numpy arrays, which avoids unnecessary
    device transfers when working with GPUs.

    Args:
        v_torch: Vector for multiplication. Has shape ``[D]`` or ``[D, 1]``.
        return_tensor: Whether to return the result as a tensor or list of tensors.

    Returns:
        Matrix-multiplication result ``KFAC⁻¹ @ v``. If tensor, has shape ``[D]``.
        If list, the leading singleton column dimension of every entry returned
        by ``torch_matmat`` is removed.

    Raises:
        ValueError: If the input tensor has the wrong shape.
    """
    # the vector is multiplied from the right, so it must match the number of
    # columns (same check as in ``torch_matmat``)
    num_cols = self.shape[1]
    if v_torch.shape not in [(num_cols,), (num_cols, 1)]:
        raise ValueError("dimension mismatch")

    # ``reshape`` (instead of ``view``) also accepts non-contiguous vectors
    result = self.torch_matmat(v_torch.reshape(-1, 1), return_tensor)

    if return_tensor:
        # tensor format is ``[D, 1]``: drop the trailing column axis
        return result.squeeze(1)
    # list format carries the column axis first: drop it per entry. Calling
    # ``.squeeze`` on the list itself would raise an ``AttributeError``.
    return [res.squeeze(0) for res in result]

def _matmat(self, M: ndarray) -> ndarray:
    """Multiply the inverse of KFAC onto a matrix (multiple vectors).

    Args:
        M: Matrix for multiplication. Has shape ``[D, K]`` with some ``K``.

    Returns:
        Matrix-multiplication result ``KFAC⁻¹ @ M``. Has shape ``[D, K]``.
    """
    kfac = self._A
    # convert with the wrapped operator's ``_preprocess``, multiply in list
    # format, and convert back with its ``_postprocess``
    product = self.torch_matmat(kfac._preprocess(M), return_tensor=False)
    return kfac._postprocess(product)
100 changes: 94 additions & 6 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from functools import partial
from math import sqrt
from typing import Dict, Iterable, List, Optional, Tuple, Union
from warnings import warn

from einops import einsum, rearrange, reduce
from numpy import ndarray
Expand Down Expand Up @@ -245,20 +246,67 @@ def to_device(self, device: device):
for key in self._gradient_covariances.keys():
self._gradient_covariances[key] = self._gradient_covariances[key].to(device)

def _matmat(self, M: ndarray) -> ndarray:
"""Apply KFAC to a matrix (multiple vectors).
def _torch_preprocess(self, M: Tensor) -> List[Tensor]:
    """Split a matrix into per-parameter blocks (torch parameter list format).

    Args:
        M: Matrix for multiplication. Has shape ``[D, K]`` where ``D`` is the
            number of parameters, and ``K`` is the number of columns.

    Returns:
        Matrix in list format. Each entry has the same shape as a parameter with
        an additional leading dimension of size ``K`` for the columns, i.e.
        ``[(K,) + p1.shape, (K,) + p2.shape, ...]``.
    """
    target_device = self._device
    if M.device != target_device:
        # warn rather than fail: the input is moved to the operator's device
        warn(
            f"Input matrix is on {M.device}, while linear operator is on "
            + f"{self._device}. Converting to {self._device}."
        )
        M = M.to(target_device)

    K = M.shape[1]
    # cut the rows of ``M`` into one chunk per parameter
    row_chunks = M.split([p.numel() for p in self._params])
    # put the column index first, then restore each parameter's shape
    return [
        rows.T.reshape((K,) + p.shape) for rows, p in zip(row_chunks, self._params)
    ]

def torch_matmat(
self, M_torch: Union[Tensor, List[Tensor]], return_tensor: bool = True
) -> Union[Tensor, List[Tensor]]:
"""Apply KFAC to a matrix (multiple vectors) in PyTorch.
runame marked this conversation as resolved.
Show resolved Hide resolved

This allows for matrix-matrix products with the KFAC approximation in PyTorch
without converting tensors to numpy arrays, which avoids unnecessary
device transfers when working with GPUs.

Args:
M_torch: Matrix for multiplication. If tensor, has shape ``[D, K]`` with
some ``K``.
return_tensor: Whether to return the result as a tensor or list of tensors.

Returns:
Matrix-multiplication result ``KFAC @ M``. If tensor, has shape ``[D, K]``.

Raises:
ValueError: If the input tensor has the wrong shape.
ValueError: If the input tensor's shape is incompatible with the KFAC
approximation's shape.
"""
if not isinstance(M_torch, list):
if M_torch.ndim != 2:
raise ValueError(f"expected 2-d tensor, not {M_torch.ndim}-d")
if M_torch.shape[0] != self.shape[1]:
raise ValueError(f"dimension mismatch: {self.shape}, {M_torch.shape}")
runame marked this conversation as resolved.
Show resolved Hide resolved
M_torch = self._torch_preprocess(M_torch)

if not self._input_covariances and not self._gradient_covariances:
self._compute_kfac()

M_torch = super()._preprocess(M)

for mod_name, param_pos in self._mapping.items():
# bias and weights are treated jointly
if (
Expand Down Expand Up @@ -295,6 +343,46 @@ def _matmat(self, M: ndarray) -> ndarray:
"j k,v k ... -> v j ...",
)

if return_tensor:
M_torch = cat([rearrange(M, "k ... -> (...) k") for M in M_torch], dim=0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (performance): Consider the efficiency of tensor concatenation.

Concatenating tensors in a loop can be inefficient, especially for large numbers of tensors. It might be beneficial to explore alternative approaches that could reduce the computational overhead, such as preallocating a tensor of the correct size and filling it.

runame marked this conversation as resolved.
Show resolved Hide resolved

return M_torch

def torch_matvec(
    self, v_torch: Tensor, return_tensor: bool = True
) -> Union[Tensor, List[Tensor]]:
    """Apply KFAC to a vector in PyTorch.

    This allows for matrix-vector products with the KFAC approximation in PyTorch
    without converting tensors to numpy arrays, which avoids unnecessary
    device transfers when working with GPUs.

    Args:
        v_torch: Vector for multiplication. Has shape ``[D]`` or ``[D, 1]``.
        return_tensor: Whether to return the result as a tensor or list of tensors.

    Returns:
        Matrix-multiplication result ``KFAC @ v``. If tensor, has shape ``[D]``.
        If list, the leading singleton column dimension of every entry returned
        by ``torch_matmat`` is removed.

    Raises:
        ValueError: If the input tensor has the wrong shape.
    """
    # the vector is multiplied from the right, so it must match the number of
    # columns (same check as in ``torch_matmat``)
    num_cols = self.shape[1]
    if v_torch.shape not in [(num_cols,), (num_cols, 1)]:
        raise ValueError("dimension mismatch")

    # ``reshape`` (instead of ``view``) also accepts non-contiguous vectors
    result = self.torch_matmat(v_torch.reshape(-1, 1), return_tensor)

    if return_tensor:
        # tensor format is ``[D, 1]``: drop the trailing column axis
        return result.squeeze(1)
    # list format carries the column axis first: drop it per entry. Calling
    # ``.squeeze`` on the list itself would raise an ``AttributeError``.
    return [res.squeeze(0) for res in result]

def _matmat(self, M: ndarray) -> ndarray:
    """Multiply KFAC onto a matrix (multiple vectors).

    Args:
        M: Matrix for multiplication. Has shape ``[D, K]`` with some ``K``.

    Returns:
        Matrix-multiplication result ``KFAC @ M``. Has shape ``[D, K]``.
    """
    # convert with the parent class' ``_preprocess``, multiply in list format,
    # and convert back with ``_postprocess``
    as_list = super()._preprocess(M)
    product = self.torch_matmat(as_list, return_tensor=False)
    return self._postprocess(product)

def _adjoint(self) -> KFACLinearOperator:
Expand Down
60 changes: 59 additions & 1 deletion test/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def test_KFAC_inverse_damped_matmat(
loss_average=loss_average,
separate_weight_and_bias=separate_weight_and_bias,
)
KFAC._compute_kfac()

# add damping manually
for aaT in KFAC._input_covariances.values():
Expand Down Expand Up @@ -185,3 +184,62 @@ def test_KFAC_inverse_damped_matmat(
# test that the cache is empty
assert len(inv_KFAC._inverse_input_covariances) == 0
assert len(inv_KFAC._inverse_gradient_covariances) == 0


def test_KFAC_inverse_damped_torch_matmat(case, delta: float = 1e-2):
    """Check ``torch_matmat`` of an inverse damped KFAC approximation."""
    model_func, loss_func, params, data = case

    kfac = KFACLinearOperator(
        model_func,
        loss_func,
        params,
        data,
        loss_average="batch" if loss_func.reduction == "mean" else None,
    )
    inv_kfac = KFACInverseLinearOperator(kfac, damping=(delta, delta))

    num_vectors = 2
    X = torch.rand(kfac.shape[1], num_vectors)
    inv_kfac_X = inv_kfac.torch_matmat(X)

    # the result must mirror the input's dtype and device and have shape [D, K]
    assert inv_kfac_X.dtype == X.dtype
    assert inv_kfac_X.device == X.device
    assert inv_kfac_X.shape == (kfac.shape[0], num_vectors)

    # Test against multiplication with dense matrix
    identity = torch.eye(inv_kfac.shape[1])
    dense_inv_kfac = inv_kfac.torch_matmat(identity)
    dense_product = dense_inv_kfac @ X
    report_nonclose(inv_kfac_X.cpu().numpy(), dense_product.cpu().numpy(), rtol=5e-4)

    # Test against _matmat
    report_nonclose(inv_kfac @ X, inv_kfac_X.cpu().numpy())


def test_KFAC_inverse_damped_torch_matvec(case, delta: float = 1e-2):
    """Check ``torch_matvec`` of an inverse damped KFAC approximation."""
    model_func, loss_func, params, data = case

    kfac = KFACLinearOperator(
        model_func,
        loss_func,
        params,
        data,
        loss_average="batch" if loss_func.reduction == "mean" else None,
    )
    inv_kfac = KFACInverseLinearOperator(kfac, damping=(delta, delta))

    x = torch.rand(kfac.shape[1])
    inv_kfac_x = inv_kfac.torch_matvec(x)

    # the result must mirror the input's dtype, device, and shape
    assert inv_kfac_x.dtype == x.dtype
    assert inv_kfac_x.device == x.device
    assert inv_kfac_x.shape == x.shape

    # Test against multiplication with dense matrix
    identity = torch.eye(inv_kfac.shape[1])
    dense_inv_kfac = inv_kfac.torch_matmat(identity)
    dense_product = dense_inv_kfac @ x
    report_nonclose(inv_kfac_x.cpu().numpy(), dense_product.cpu().numpy(), rtol=5e-5)

    # Test against _matmat
    report_nonclose(inv_kfac @ x, inv_kfac_x.cpu().numpy())
65 changes: 64 additions & 1 deletion test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from numpy import eye
from pytest import mark, skip
from scipy.linalg import block_diag
from torch import Tensor, cuda, device, manual_seed, rand, randperm
from torch import Tensor, cuda, device
from torch import eye as torch_eye
from torch import manual_seed, rand, randperm
from torch.nn import (
CrossEntropyLoss,
Flatten,
Expand Down Expand Up @@ -426,3 +428,64 @@ def test_bug_device_change_invalidates_parameter_mapping():
kfac_x_cpu = kfac @ x

report_nonclose(kfac_x_gpu, kfac_x_cpu)


@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
def test_torch_matmat(dev: device):
    """Check the ``torch_matmat`` method of ``KFACLinearOperator``."""
    manual_seed(0)

    # tiny regression problem living entirely on ``dev``
    inputs = rand(2, 5, device=dev)
    targets = regression_targets((2, 4)).to(dev)
    data = [(inputs, targets)]
    model = Sequential(Linear(5, 4), ReLU(), Linear(4, 4)).to(dev)
    loss_func = MSELoss().to(dev)

    kfac = KFACLinearOperator(
        model,
        loss_func,
        list(model.parameters()),
        data,
        fisher_type="empirical",
    )

    x = rand(kfac.shape[1], 16, device=dev)
    kfac_x = kfac.torch_matmat(x)

    # the result must mirror the input's device and dtype and have shape [D, K]
    assert x.device == kfac_x.device
    assert x.dtype == kfac_x.dtype
    assert kfac_x.shape == (kfac.shape[0], x.shape[1])
    kfac_x_np = kfac_x.cpu().numpy()

    # compare with the dense matrix obtained by multiplying onto the identity
    identity = torch_eye(kfac.shape[1], device=dev)
    dense_kfac = kfac.torch_matmat(identity)
    dense_product = dense_kfac @ x
    report_nonclose(kfac_x_np, dense_product.cpu().numpy())

    # compare with the NumPy code path
    numpy_product = kfac @ x.cpu().numpy()
    report_nonclose(kfac_x.cpu().numpy(), numpy_product)


@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
def test_torch_matvec(dev: device):
    """Check the ``torch_matvec`` method of ``KFACLinearOperator``."""
    manual_seed(0)

    # tiny regression problem living entirely on ``dev``
    inputs = rand(2, 5, device=dev)
    targets = regression_targets((2, 4)).to(dev)
    data = [(inputs, targets)]
    model = Sequential(Linear(5, 4), ReLU(), Linear(4, 4)).to(dev)
    loss_func = MSELoss().to(dev)

    kfac = KFACLinearOperator(
        model,
        loss_func,
        list(model.parameters()),
        data,
    )

    x = rand(kfac.shape[1], device=dev)
    kfac_x = kfac.torch_matvec(x)

    # the result must mirror the input's device, dtype, and shape
    assert x.device == kfac_x.device
    assert x.dtype == kfac_x.dtype
    assert kfac_x.shape == x.shape
    kfac_x_np = kfac_x.cpu().numpy()

    # compare with the dense matrix obtained by multiplying onto the identity
    identity = torch_eye(kfac.shape[1], device=dev)
    dense_kfac = kfac.torch_matmat(identity)
    dense_product = dense_kfac @ x
    report_nonclose(kfac_x_np, dense_product.cpu().numpy())

    # compare with the NumPy code path
    numpy_product = kfac @ x.cpu().numpy()
    report_nonclose(kfac_x.cpu().numpy(), numpy_product)
Loading