Allow for matrix-matrix and matrix-vector products with KFACLinearOperator and KFACInverseLinearOperator without converting to numpy #91

Merged — 16 commits, merged Mar 20, 2024. Changes from 13 commits.
74 changes: 67 additions & 7 deletions curvlinops/inverse.py
@@ -1,6 +1,6 @@
"""Implements linear operator inverses."""

from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

from einops import einsum, rearrange
from numpy import allclose, column_stack, ndarray
@@ -274,20 +274,32 @@ def _compute_or_get_cached_inverse(
self._inverse_gradient_covariances[name] = ggT_inv
return aaT_inv, ggT_inv

def _matmat(self, M: ndarray) -> ndarray:
"""Multiply a matrix ``M`` x by the inverse of KFAC.
def torch_matmat(
self, M_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
"""Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch.

This allows for matrix-matrix products with the inverse KFAC approximation in
PyTorch without converting tensors to numpy arrays, which avoids unnecessary
device transfers when working with GPUs, as well as unnecessary flattening and concatenation.

Args:
M: Matrix for multiplication.
M_torch: Matrix for multiplication. If list of tensors, each entry has the
same shape as a parameter with an additional leading dimension of size
``K`` for the columns, i.e. ``[(K,) + p1.shape, (K,) + p2.shape, ...]``.
If tensor, has shape ``[D, K]`` with some ``K``.

Returns:
Result of inverse matrix-matrix multiplication, ``KFAC⁻¹ @ M``.
Matrix-multiplication result ``KFAC⁻¹ @ M``. Return type is the same as the
type of the input. If list of tensors, each entry has the same shape as a
parameter with an additional leading dimension of size ``K`` for the columns,
i.e. ``[(K,) + p1.shape, (K,) + p2.shape, ...]``. If tensor, has shape
``[D, K]`` with some ``K``.
"""
return_tensor, M_torch = self._A._check_input_type_and_preprocess(M_torch)
if not self._A._input_covariances and not self._A._gradient_covariances:
self._A._compute_kfac()

M_torch = self._A._preprocess(M)

for mod_name, param_pos in self._A._mapping.items():
# retrieve the inverses of the Kronecker factors from cache or invert them
aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name)
@@ -318,4 +330,52 @@ def _matmat(self, M: ndarray) -> ndarray:
ggT_inv, M_torch[pos], "i j, m j ... -> m i ..."
)

if return_tensor:
M_torch = cat([rearrange(M, "k ... -> (...) k") for M in M_torch], dim=0)

return M_torch
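A minimal usage sketch of the new PyTorch interface (assuming `inv_KFAC` is a `KFACInverseLinearOperator` built from a `KFACLinearOperator` named `KFAC`, as in the tests below; shapes are illustrative):

```python
import torch

# Assumed setup, as in the tests below:
#   KFAC = KFACLinearOperator(model_func, loss_func, params, data)
#   inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(1e-2, 1e-2))
D, K = inv_KFAC.shape[1], 3

# Flat format: K column vectors stacked into a [D, K] tensor.
M = torch.rand(D, K)
out = inv_KFAC.torch_matmat(M)  # [D, K], no numpy round trip

# Parameter-list format: one tensor of shape (K,) + p.shape per parameter.
M_list = [torch.rand(K, *p.shape) for p in KFAC._params]
out_list = inv_KFAC.torch_matmat(M_list)  # returned in the same list format

# Matrix-vector product with a flat [D] vector.
v = torch.rand(D)
u = inv_KFAC.torch_matvec(v)  # [D]
```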

def torch_matvec(
self, v_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
"""Apply the inverse of KFAC to a vector in PyTorch.

This allows for matrix-vector products with the inverse KFAC approximation in
PyTorch without converting tensors to numpy arrays, which avoids unnecessary
device transfers when working with GPUs, as well as unnecessary flattening and concatenation.

Args:
v_torch: Vector for multiplication. If list of tensors, each entry has the
same shape as a parameter, i.e. ``[p1.shape, p2.shape, ...]``.
If tensor, has shape ``[D]``.

Returns:
Matrix-multiplication result ``KFAC⁻¹ @ v``. Return type is the same as the
type of the input. If list of tensors, each entry has the same shape as a
parameter, i.e. ``[p1.shape, p2.shape, ...]``. If tensor, has shape ``[D]``.

Raises:
ValueError: If the input tensor has the wrong shape.
"""
if isinstance(v_torch, list):
v_torch = [v_torch_i.unsqueeze(0) for v_torch_i in v_torch]
result = self.torch_matmat(v_torch)
return [res.squeeze(0) for res in result]
else:
M = self.shape[0]
if v_torch.shape != (M,):
raise ValueError("The input vector has the wrong shape.")
return self.torch_matmat(v_torch.unsqueeze(1)).squeeze(1)

def _matmat(self, M: ndarray) -> ndarray:
"""Apply the inverse of KFAC to a matrix (multiple vectors).

Args:
M: Matrix for multiplication. Has shape ``[D, K]`` with some ``K``.

Returns:
Matrix-multiplication result ``KFAC⁻¹ @ M``. Has shape ``[D, K]``.
"""
M_torch = self._A._preprocess(M)
M_torch = self.torch_matmat(M_torch)
return self._A._postprocess(M_torch)
140 changes: 133 additions & 7 deletions curvlinops/kfac.py
@@ -245,20 +245,98 @@ def to_device(self, device: device):
for key in self._gradient_covariances.keys():
self._gradient_covariances[key] = self._gradient_covariances[key].to(device)

def _matmat(self, M: ndarray) -> ndarray:
"""Apply KFAC to a matrix (multiple vectors).
def _torch_preprocess(self, M: Tensor) -> List[Tensor]:
"""Convert torch tensor to torch parameter list format.

Args:
M: Matrix for multiplication. Has shape ``[D, K]`` with some ``K``.
M: Matrix for multiplication. Has shape ``[D, K]`` where ``D`` is the
number of parameters, and ``K`` is the number of columns.

Returns:
Matrix-multiplication result ``KFAC @ M``. Has shape ``[D, K]``.
Matrix in list format. Each entry has the same shape as a parameter with
an additional leading dimension of size ``K`` for the columns, i.e.
``[(K,) + p1.shape, (K,) + p2.shape, ...]``.
"""
num_vectors = M.shape[1]
# split parameter blocks
dims = [p.numel() for p in self._params]
result = M.split(dims)
# column-index first + unflatten parameter dimension
shapes = [(num_vectors,) + p.shape for p in self._params]
result = [res.T.reshape(shape) for res, shape in zip(result, shapes)]
return result
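To make the reshaping concrete, a standalone sketch of the same split-and-unflatten steps with made-up parameter shapes:

```python
import torch

# Made-up parameters: a 2x3 weight and a bias of size 2, so D = 8.
params = [torch.zeros(2, 3), torch.zeros(2)]
K = 4
M = torch.arange(8.0 * K).reshape(8, K)  # flat [D, K] input

# Same steps as ``_torch_preprocess``: split the D axis into per-parameter
# blocks, then move the column index to the front and unflatten.
dims = [p.numel() for p in params]  # [6, 2]
blocks = [b.T.reshape((K,) + p.shape) for b, p in zip(M.split(dims), params)]
assert [tuple(b.shape) for b in blocks] == [(4, 2, 3), (4, 2)]
```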

def _check_input_type_and_preprocess(
self, M_torch: Union[Tensor, List[Tensor]]
) -> Tuple[bool, List[Tensor]]:
"""Check input type and maybe preprocess to list format.

Check whether the input is a tensor or a list of tensors. If it is a tensor,
preprocess to list format.

Args:
M_torch: Input to check.

Returns:
A flag that is ``True`` if the input is a tensor and ``False`` if it is a list of tensors, together with the input in list format.

Raises:
ValueError: If the input is a list of tensors that have a different number
of columns.
ValueError: If the input is a list of tensors that have incompatible shapes
with the parameters.
ValueError: If the input is a tensor and has the wrong shape.
ValueError: If the input is a tensor and its shape is incompatible with the
KFAC approximation's shape.
"""
if isinstance(M_torch, list):
return_tensor = False
K = len(M_torch[0])
assert len(M_torch) == len(self._params)
for M, p in zip(M_torch, self._params):
if len(M) != K:
raise ValueError(
"All input tensors must have the same number of columns."
)
if M.shape[1:] != p.shape:
raise ValueError(
"All input tensors must have (K,) + the same shape as the parameters."
)
else:
return_tensor = True
if M_torch.ndim != 2:
raise ValueError(f"expected 2-d tensor, not {M_torch.ndim}-d")
if M_torch.shape[0] != self.shape[1]:
raise ValueError(f"dimension mismatch: {self.shape}, {M_torch.shape}")
M_torch = self._torch_preprocess(M_torch)
return return_tensor, M_torch
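For illustration, the checks above reject malformed flat inputs early (a sketch, assuming a constructed `KFACLinearOperator` named `KFAC`):

```python
import torch

D = KFAC.shape[1]  # total number of parameters

for bad in (torch.rand(D), torch.rand(D + 1, 2)):  # 1-d input, wrong leading dim
    try:
        KFAC.torch_matmat(bad)
    except ValueError as e:
        print(e)  # "expected 2-d tensor, not 1-d" / "dimension mismatch: ..."
```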

def torch_matmat(
self, M_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
"""Apply KFAC to a matrix (multiple vectors) in PyTorch.

This allows for matrix-matrix products with the KFAC approximation in PyTorch
without converting tensors to numpy arrays, which avoids unnecessary
device transfers when working with GPUs, as well as unnecessary flattening and concatenation.

Args:
M_torch: Matrix for multiplication. If list of tensors, each entry has the
same shape as a parameter with an additional leading dimension of size
``K`` for the columns, i.e. ``[(K,) + p1.shape, (K,) + p2.shape, ...]``.
If tensor, has shape ``[D, K]`` with some ``K``.

Returns:
Matrix-multiplication result ``KFAC @ M``. Return type is the same as the
type of the input. If list of tensors, each entry has the same shape as a
parameter with an additional leading dimension of size ``K`` for the columns,
i.e. ``[(K,) + p1.shape, (K,) + p2.shape, ...]``. If tensor, has shape
``[D, K]`` with some ``K``.
"""
return_tensor, M_torch = self._check_input_type_and_preprocess(M_torch)
if not self._input_covariances and not self._gradient_covariances:
self._compute_kfac()

M_torch = super()._preprocess(M)

for mod_name, param_pos in self._mapping.items():
# bias and weights are treated jointly
if (
@@ -295,6 +373,54 @@ def _matmat(self, M: ndarray) -> ndarray:
"j k,v k ... -> v j ...",
)

if return_tensor:
M_torch = cat([rearrange(M, "k ... -> (...) k") for M in M_torch], dim=0)
Contributor review comment — suggestion (performance): Consider the efficiency of tensor concatenation. Concatenating tensors in a loop can be inefficient, especially for large numbers of tensors. It might be beneficial to explore alternative approaches that could reduce the computational overhead, such as preallocating a tensor of the correct size and filling it.
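A sketch of the reviewer's suggestion (hypothetical helper, not part of this commit): preallocate the flat `[D, K]` result and write each parameter block into its slice instead of concatenating:

```python
from typing import List

import torch
from torch import Tensor


def flatten_blocks(blocks: List[Tensor], total_dim: int) -> Tensor:
    """Flatten per-parameter blocks of shape (K,) + p.shape into [D, K]."""
    K = blocks[0].shape[0]
    out = torch.empty(total_dim, K, dtype=blocks[0].dtype, device=blocks[0].device)
    offset = 0
    for block in blocks:
        numel = block[0].numel()  # entries one column contributes for this block
        # flatten the parameter dims, then put the column index last
        out[offset : offset + numel] = block.reshape(K, numel).T
        offset += numel
    return out
```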

return M_torch

def torch_matvec(
self, v_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
"""Apply KFAC to a vector in PyTorch.

This allows for matrix-vector products with the KFAC approximation in PyTorch
without converting tensors to numpy arrays, which avoids unnecessary
device transfers when working with GPUs, as well as unnecessary flattening and concatenation.

Args:
v_torch: Vector for multiplication. If list of tensors, each entry has the
same shape as a parameter, i.e. ``[p1.shape, p2.shape, ...]``.
If tensor, has shape ``[D]``.

Returns:
Matrix-multiplication result ``KFAC @ v``. Return type is the same as the
type of the input. If list of tensors, each entry has the same shape as a
parameter, i.e. ``[p1.shape, p2.shape, ...]``. If tensor, has shape ``[D]``.

Raises:
ValueError: If the input tensor has the wrong shape.
"""
if isinstance(v_torch, list):
v_torch = [v_torch_i.unsqueeze(0) for v_torch_i in v_torch]
result = self.torch_matmat(v_torch)
return [res.squeeze(0) for res in result]
else:
M = self.shape[0]
if v_torch.shape != (M,):
raise ValueError("The input vector has the wrong shape.")
return self.torch_matmat(v_torch.unsqueeze(1)).squeeze(1)

def _matmat(self, M: ndarray) -> ndarray:
"""Apply KFAC to a matrix (multiple vectors).

Args:
M: Matrix for multiplication. Has shape ``[D, K]`` with some ``K``.

Returns:
Matrix-multiplication result ``KFAC @ M``. Has shape ``[D, K]``.
"""
M_torch = super()._preprocess(M)
M_torch = self.torch_matmat(M_torch)
return self._postprocess(M_torch)
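Since the SciPy-facing `_matmat` is now a thin wrapper around `torch_matmat`, the numpy and torch paths should agree; a quick consistency sketch (assuming a constructed `KFACLinearOperator` named `KFAC` on CPU):

```python
import torch

M = torch.rand(KFAC.shape[1], 2)

out_torch = KFAC.torch_matmat(M)  # stays in PyTorch end to end
out_numpy = KFAC @ M.numpy()      # SciPy interface, routed through _matmat

torch.testing.assert_close(
    out_torch, torch.from_numpy(out_numpy).to(out_torch.dtype)
)
```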

def _adjoint(self) -> KFACLinearOperator:
@@ -308,7 +434,7 @@ def _adjoint(self) -> KFACLinearOperator:
return self

def _compute_kfac(self):
"""Compute and cache KFAC's Kronecker factors for future ``matvec``s."""
"""Compute and cache KFAC's Kronecker factors for future ``matmat``s."""
# install forward and backward hooks
hook_handles: List[RemovableHandle] = []

84 changes: 83 additions & 1 deletion test/test_inverse.py
@@ -1,6 +1,7 @@
"""Contains tests for ``curvlinops/inverse``."""

import torch
from einops import rearrange
from numpy import array, eye, random
from numpy.linalg import eigh, inv
from pytest import mark, raises
@@ -156,7 +157,6 @@ def test_KFAC_inverse_damped_matmat(
loss_average=loss_average,
separate_weight_and_bias=separate_weight_and_bias,
)
KFAC._compute_kfac()

# add damping manually
for aaT in KFAC._input_covariances.values():
@@ -185,3 +185,85 @@
# test that the cache is empty
assert len(inv_KFAC._inverse_input_covariances) == 0
assert len(inv_KFAC._inverse_gradient_covariances) == 0


def test_KFAC_inverse_damped_torch_matmat(case, delta: float = 1e-2):
"""Test torch matrix-matrix multiplication by an inverse damped KFAC approximation."""
model_func, loss_func, params, data = case

loss_average = "batch" if loss_func.reduction == "mean" else None
KFAC = KFACLinearOperator(
model_func,
loss_func,
params,
data,
loss_average=loss_average,
)
inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(delta, delta))

num_vectors = 2
X = torch.rand(KFAC.shape[1], num_vectors)
inv_KFAC_X = inv_KFAC.torch_matmat(X)
assert inv_KFAC_X.dtype == X.dtype
assert inv_KFAC_X.device == X.device
assert inv_KFAC_X.shape == (KFAC.shape[0], num_vectors)
inv_KFAC_X = inv_KFAC_X.cpu().numpy()

# Test list input format
x_list = KFAC._torch_preprocess(X)
inv_KFAC_x_list = inv_KFAC.torch_matmat(x_list)
inv_KFAC_x_list = torch.cat(
[rearrange(M, "k ... -> (...) k") for M in inv_KFAC_x_list], dim=0
)
report_nonclose(inv_KFAC_X, inv_KFAC_x_list.cpu().numpy())

# Test against multiplication with dense matrix
inv_KFAC_mat = inv_KFAC.torch_matmat(torch.eye(inv_KFAC.shape[1]))
inv_KFAC_mat_x = inv_KFAC_mat @ X
report_nonclose(inv_KFAC_X, inv_KFAC_mat_x.cpu().numpy(), rtol=5e-4)

# Test against _matmat
kfac_x_numpy = inv_KFAC @ X.cpu().numpy()
report_nonclose(inv_KFAC_X, kfac_x_numpy)


def test_KFAC_inverse_damped_torch_matvec(case, delta: float = 1e-2):
"""Test torch matrix-vector multiplication by an inverse damped KFAC approximation."""
model_func, loss_func, params, data = case

loss_average = "batch" if loss_func.reduction == "mean" else None
KFAC = KFACLinearOperator(
model_func,
loss_func,
params,
data,
loss_average=loss_average,
)
inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(delta, delta))

x = torch.rand(KFAC.shape[1])
inv_KFAC_x = inv_KFAC.torch_matvec(x)
assert inv_KFAC_x.dtype == x.dtype
assert inv_KFAC_x.device == x.device
assert inv_KFAC_x.shape == x.shape

# Test list input format
# split parameter blocks
dims = [p.numel() for p in KFAC._params]
split_x = x.split(dims)
# unflatten parameter dimension
assert len(split_x) == len(KFAC._params)
x_list = [res.reshape(p.shape) for res, p in zip(split_x, KFAC._params)]
inv_kfac_x_list = inv_KFAC.torch_matvec(x_list)
inv_kfac_x_list = torch.cat(
[rearrange(M, "... -> (...)") for M in inv_kfac_x_list], dim=0
)
report_nonclose(inv_KFAC_x, inv_kfac_x_list.cpu().numpy())

# Test against multiplication with dense matrix
inv_KFAC_mat = inv_KFAC.torch_matmat(torch.eye(inv_KFAC.shape[1]))
inv_KFAC_mat_x = inv_KFAC_mat @ x
report_nonclose(inv_KFAC_x.cpu().numpy(), inv_KFAC_mat_x.cpu().numpy(), rtol=5e-5)

# Test against _matmat
report_nonclose(inv_KFAC @ x, inv_KFAC_x.cpu().numpy())