Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for matrix-matrix and matrix-vector products with KFACLinearOperator and KFACInverseLinearOperator without converting to numpy #91

Merged
merged 16 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 94 additions & 6 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from functools import partial
from math import sqrt
from typing import Dict, Iterable, List, Optional, Tuple, Union
from warnings import warn

from einops import einsum, rearrange, reduce
from numpy import ndarray
Expand Down Expand Up @@ -245,20 +246,67 @@ def to_device(self, device: device):
for key in self._gradient_covariances.keys():
self._gradient_covariances[key] = self._gradient_covariances[key].to(device)

def _matmat(self, M: ndarray) -> ndarray:
"""Apply KFAC to a matrix (multiple vectors).
def _torch_preprocess(self, M: Tensor) -> List[Tensor]:
    """Convert torch tensor to torch parameter list format.

    If ``M`` lives on a different device than the linear operator, a warning is
    emitted and ``M`` is moved to the operator's device before conversion.

    Args:
        M: Matrix for multiplication. Has shape ``[D, K]`` where ``D`` is the
            number of parameters, and ``K`` is the number of columns.

    Returns:
        Matrix in list format. Each entry has the same shape as a parameter with
        an additional leading dimension of size ``K`` for the columns, i.e.
        ``[(K,) + p1.shape, (K,) + p2.shape, ...]``.
    """
    if M.device != self._device:
        warn(
            f"Input matrix is on {M.device}, while linear operator is on "
            + f"{self._device}. Converting to {self._device}."
        )
        M = M.to(self._device)

    num_vectors = M.shape[1]
    # split the rows into one chunk per parameter
    dims = [p.numel() for p in self._params]
    chunks = M.split(dims)
    # column-index first + unflatten parameter dimension
    return [
        chunk.T.reshape(num_vectors, *p.shape)
        for chunk, p in zip(chunks, self._params)
    ]

def torch_matmat(
self, M_torch: Union[Tensor, List[Tensor]], return_tensor: bool = True
) -> Union[Tensor, List[Tensor]]:
"""Apply KFAC to a matrix (multiple vectors) in PyTorch.
runame marked this conversation as resolved.
Show resolved Hide resolved

This allows for matrix-matrix products with the KFAC approximation in PyTorch
without converting tensors to numpy arrays, which avoids unnecessary
device transfers when working with GPUs.

Args:
M_torch: Matrix for multiplication. If tensor, has shape ``[D, K]`` with
some ``K``.
return_tensor: Whether to return the result as a tensor or list of tensors.

Returns:
Matrix-multiplication result ``KFAC @ M``. If tensor, has shape ``[D, K]``.

Raises:
ValueError: If the input tensor has the wrong shape.
ValueError: If the input tensor's shape is incompatible with the KFAC
approximation's shape.
"""
if not isinstance(M_torch, list):
if M_torch.ndim != 2:
raise ValueError(f"expected 2-d tensor, not {M_torch.ndim}-d")
if M_torch.shape[0] != self.shape[1]:
raise ValueError(f"dimension mismatch: {self.shape}, {M_torch.shape}")
runame marked this conversation as resolved.
Show resolved Hide resolved
M_torch = self._torch_preprocess(M_torch)

if not self._input_covariances and not self._gradient_covariances:
self._compute_kfac()

M_torch = super()._preprocess(M)

for mod_name, param_pos in self._mapping.items():
# bias and weights are treated jointly
if (
Expand Down Expand Up @@ -295,6 +343,46 @@ def _matmat(self, M: ndarray) -> ndarray:
"j k,v k ... -> v j ...",
)

if return_tensor:
M_torch = cat([rearrange(M, "k ... -> (...) k") for M in M_torch], dim=0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (performance): Consider the efficiency of tensor concatenation.

Concatenating tensors in a loop can be inefficient, especially for large numbers of tensors. It might be beneficial to explore alternative approaches that could reduce the computational overhead, such as preallocating a tensor of the correct size and filling it.

runame marked this conversation as resolved.
Show resolved Hide resolved

return M_torch

def torch_matvec(
    self, v_torch: Tensor, return_tensor: bool = True
) -> Union[Tensor, List[Tensor]]:
    """Apply KFAC to a vector in PyTorch.

    This allows for matrix-vector products with the KFAC approximation in PyTorch
    without converting tensors to numpy arrays, which avoids unnecessary
    device transfers when working with GPUs.

    Args:
        v_torch: Vector for multiplication. Has shape ``[D]`` or ``[D, 1]``.
        return_tensor: Whether to return the result as a tensor or list of tensors.

    Returns:
        Matrix-vector product ``KFAC @ v``. If ``return_tensor`` is ``True``, a
        tensor of shape ``[D]``. Otherwise the list produced by ``torch_matmat``
        with the singleton column dimension (leading axis) removed from each
        entry.

    Raises:
        ValueError: If the input tensor has the wrong shape.
    """
    # the input must match the operator's number of columns, as in torch_matmat
    num_columns = self.shape[1]
    if v_torch.shape not in ((num_columns,), (num_columns, 1)):
        raise ValueError(
            f"dimension mismatch: {self.shape}, {tuple(v_torch.shape)}"
        )

    result = self.torch_matmat(v_torch.view(-1, 1), return_tensor)
    if return_tensor:
        # drop the single column: [D, 1] -> [D]
        return result.squeeze(1)
    # list entries carry the column index first; a list has no ``squeeze``
    return [entry.squeeze(0) for entry in result]

def _matmat(self, M: ndarray) -> ndarray:
    """Multiply KFAC onto a NumPy matrix (SciPy linear operator interface).

    The input is converted with the parent class' ``_preprocess``, multiplied
    via ``torch_matmat`` in list format, and converted back with
    ``_postprocess``.

    Args:
        M: Matrix for multiplication. Has shape ``[D, K]`` with some ``K``.

    Returns:
        Matrix-multiplication result ``KFAC @ M``. Has shape ``[D, K]``.
    """
    blocks = super()._preprocess(M)
    product = self.torch_matmat(blocks, return_tensor=False)
    return self._postprocess(product)

def _adjoint(self) -> KFACLinearOperator:
Expand Down
65 changes: 64 additions & 1 deletion test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from numpy import eye
from pytest import mark, skip
from scipy.linalg import block_diag
from torch import Tensor, cuda, device, manual_seed, rand, randperm
from torch import Tensor, cuda, device
from torch import eye as torch_eye
from torch import manual_seed, rand, randperm
from torch.nn import (
CrossEntropyLoss,
Flatten,
Expand Down Expand Up @@ -426,3 +428,64 @@ def test_bug_device_change_invalidates_parameter_mapping():
kfac_x_cpu = kfac @ x

report_nonclose(kfac_x_gpu, kfac_x_cpu)


@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
def test_torch_matmat(dev: device):
    """Check ``KFACLinearOperator.torch_matmat`` on a small regression problem.

    The PyTorch product must keep the input's device and dtype, have one row
    per operator row and one column per input column, and agree with both the
    densified operator and the NumPy ``@`` interface.

    Args:
        dev: Device on which data, model, loss, and test matrices are created.
    """
    manual_seed(0)

    # NOTE: the creation order below fixes the random-number stream; keep it.
    inputs = rand(2, 5, device=dev)
    targets = regression_targets((2, 4)).to(dev)
    model = Sequential(Linear(5, 4), ReLU(), Linear(4, 4)).to(dev)
    loss_func = MSELoss().to(dev)

    kfac = KFACLinearOperator(
        model,
        loss_func,
        list(model.parameters()),
        [(inputs, targets)],
        fisher_type="empirical",
    )

    x = rand(kfac.shape[1], 16, device=dev)
    kfac_x = kfac.torch_matmat(x)
    assert x.device == kfac_x.device
    assert x.dtype == kfac_x.dtype
    assert kfac_x.shape == (kfac.shape[0], x.shape[1])
    kfac_x_np = kfac_x.cpu().numpy()

    # densify the operator by multiplying onto the identity, then compare
    identity = torch_eye(kfac.shape[1], device=dev)
    dense_product = kfac.torch_matmat(identity) @ x
    report_nonclose(kfac_x_np, dense_product.cpu().numpy())

    # the NumPy interface must give the same result
    numpy_product = kfac @ x.cpu().numpy()
    report_nonclose(kfac_x_np, numpy_product)


@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
def test_torch_matvec(dev: device):
    """Check ``KFACLinearOperator.torch_matvec`` on a small regression problem.

    The PyTorch product must keep the input's device, dtype, and shape, and
    agree with both the densified operator and the NumPy ``@`` interface.

    Args:
        dev: Device on which data, model, loss, and test vectors are created.
    """
    manual_seed(0)

    # NOTE: the creation order below fixes the random-number stream; keep it.
    inputs = rand(2, 5, device=dev)
    targets = regression_targets((2, 4)).to(dev)
    model = Sequential(Linear(5, 4), ReLU(), Linear(4, 4)).to(dev)
    loss_func = MSELoss().to(dev)

    kfac = KFACLinearOperator(
        model,
        loss_func,
        list(model.parameters()),
        [(inputs, targets)],
    )

    x = rand(kfac.shape[1], device=dev)
    kfac_x = kfac.torch_matvec(x)
    assert x.device == kfac_x.device
    assert x.dtype == kfac_x.dtype
    assert kfac_x.shape == x.shape
    kfac_x_np = kfac_x.cpu().numpy()

    # densify the operator by multiplying onto the identity, then compare
    identity = torch_eye(kfac.shape[1], device=dev)
    dense_product = kfac.torch_matmat(identity) @ x
    report_nonclose(kfac_x_np, dense_product.cpu().numpy())

    # the NumPy interface must give the same result
    numpy_product = kfac @ x.cpu().numpy()
    report_nonclose(kfac_x_np, numpy_product)
Loading