Skip to content

Commit

Permalink
Allow for matrix-matrix and matrix-vector products with `KFACLinearOp…
Browse files Browse the repository at this point in the history
…erator` and `KFACInverseLinearOperator` without converting to numpy (#91)

* Add tests for torch_matmat and torch_matvec

* Add torch_matmat and torch_matvec methods to KFACLinearOperator

* Nicer if-statement

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Fix doc string

* Add tests for torch_matmat and torch_matvec for KFACInverseLinearOperator

* Add torch_matmat and torch_matvec to KFACInverseLinearOperator

* Increase test tolerance

* Deduce torch_matmat return type from input type

* Remove device check

* Make constant explicit

* Remove leftover import

* Add support for list inputs to torch_matvec

* Add tests for list inputs to torch_matmat and torch_matvec

* Address review feedback

* Switch to case for kfac tests and minor fixes

* Fix flake8

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
runame and sourcery-ai[bot] authored Mar 20, 2024
1 parent 0c84391 commit 67e043f
Show file tree
Hide file tree
Showing 4 changed files with 391 additions and 17 deletions.
75 changes: 68 additions & 7 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Implements linear operator inverses."""

from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

from einops import einsum, rearrange
from numpy import allclose, column_stack, ndarray
Expand Down Expand Up @@ -274,20 +274,32 @@ def _compute_or_get_cached_inverse(
self._inverse_gradient_covariances[name] = ggT_inv
return aaT_inv, ggT_inv

def _matmat(self, M: ndarray) -> ndarray:
"""Multiply a matrix ``M`` x by the inverse of KFAC.
def torch_matmat(
self, M_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
"""Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch.
This allows for matrix-matrix products with the inverse KFAC approximation in
PyTorch without converting tensors to numpy arrays, which avoids unnecessary
device transfers when working with GPUs and flattening/concatenating.
Args:
M: Matrix for multiplication.
M_torch: Matrix for multiplication. If list of tensors, each entry has the
same shape as a parameter with an additional leading dimension of size
``K`` for the columns, i.e. ``[(K,) + p1.shape, (K,) + p2.shape, ...]``.
If tensor, has shape ``[D, K]`` with some ``K``.
Returns:
Result of inverse matrix-matrixmultiplication, ``KFAC⁻¹ @ M``.
Matrix-multiplication result ``KFAC⁻¹ @ M``. Return type is the same as the
type of the input. If list of tensors, each entry has the same shape as a
parameter with an additional leading dimension of size ``K`` for the columns,
i.e. ``[(K,) + p1.shape, (K,) + p2.shape, ...]``. If tensor, has shape
``[D, K]`` with some ``K``.
"""
return_tensor, M_torch = self._A._check_input_type_and_preprocess(M_torch)
if not self._A._input_covariances and not self._A._gradient_covariances:
self._A._compute_kfac()

M_torch = self._A._preprocess(M)

for mod_name, param_pos in self._A._mapping.items():
# retrieve the inverses of the Kronecker factors from cache or invert them
aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name)
Expand Down Expand Up @@ -318,4 +330,53 @@ def _matmat(self, M: ndarray) -> ndarray:
ggT_inv, M_torch[pos], "i j, m j ... -> m i ..."
)

if return_tensor:
M_torch = cat([rearrange(M, "k ... -> (...) k") for M in M_torch])

return M_torch

def torch_matvec(
    self, v_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
    """Apply the inverse of KFAC to a vector in PyTorch.

    This allows for matrix-vector products with the inverse KFAC approximation in
    PyTorch without converting tensors to numpy arrays, which avoids unnecessary
    device transfers when working with GPUs and flattening/concatenating.

    Args:
        v_torch: Vector for multiplication. If list of tensors, each entry has the
            same shape as a parameter, i.e. ``[p1.shape, p2.shape, ...]``.
            If tensor, has shape ``[D]``.

    Returns:
        Matrix-multiplication result ``KFAC⁻¹ @ v``. Return type is the same as the
        type of the input. If list of tensors, each entry has the same shape as a
        parameter, i.e. ``[p1.shape, p2.shape, ...]``. If tensor, has shape ``[D]``.

    Raises:
        ValueError: If the input is neither a tensor nor a list of tensors.
        ValueError: If the input is a tensor that is not 1-d.
    """
    if isinstance(v_torch, list):
        # add a leading column axis of size 1 so the matrix code path can be used
        columns = [entry.unsqueeze(0) for entry in v_torch]
        return [res.squeeze(0) for res in self.torch_matmat(columns)]

    if isinstance(v_torch, Tensor):
        # reject non-vectors here; otherwise the failure surfaces later as a
        # confusing complaint about the unsqueezed tensor's dimensionality
        if v_torch.ndim != 1:
            raise ValueError(f"expected 1-d tensor, not {v_torch.ndim}-d")
        return self.torch_matmat(v_torch.unsqueeze(-1)).squeeze(-1)

    raise ValueError(
        f"Invalid input type: {type(v_torch)}. Expected list of tensors or tensor."
    )

def _matmat(self, M: ndarray) -> ndarray:
"""Apply the inverse of KFAC to a matrix (multiple vectors).
Args:
M: Matrix for multiplication. Has shape ``[D, K]`` with some ``K``.
Returns:
Matrix-multiplication result ``KFAC⁻¹ @ M``. Has shape ``[D, K]``.
"""
M_torch = self._A._preprocess(M)
M_torch = self.torch_matmat(M_torch)
return self._A._postprocess(M_torch)
146 changes: 139 additions & 7 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,20 +245,103 @@ def to_device(self, device: device):
for key in self._gradient_covariances.keys():
self._gradient_covariances[key] = self._gradient_covariances[key].to(device)

def _matmat(self, M: ndarray) -> ndarray:
"""Apply KFAC to a matrix (multiple vectors).
def _torch_preprocess(self, M: Tensor) -> List[Tensor]:
"""Convert torch tensor to torch parameter list format.
Args:
M: Matrix for multiplication. Has shape ``[D, K]`` with some ``K``.
M: Matrix for multiplication. Has shape ``[D, K]`` where ``D`` is the
number of parameters, and ``K`` is the number of columns.
Returns:
Matrix-multiplication result ``KFAC @ M``. Has shape ``[D, K]``.
Matrix in list format. Each entry has the same shape as a parameter with
an additional leading dimension of size ``K`` for the columns, i.e.
``[(K,) + p1.shape, (K,) + p2.shape, ...]``.
"""
num_vectors = M.shape[1]
# split parameter blocks
dims = [p.numel() for p in self._params]
result = M.split(dims)
# column-index first + unflatten parameter dimension
shapes = [(num_vectors,) + p.shape for p in self._params]
return [res.T.reshape(shape) for res, shape in zip(result, shapes)]

def _check_input_type_and_preprocess(
self, M_torch: Union[Tensor, List[Tensor]]
) -> Tuple[bool, List[Tensor]]:
"""Check input type and maybe preprocess to list format.
Check whether the input is a tensor or a list of tensors. If it is a tensor,
preprocess to list format.
Args:
M_torch: Input to check.
Returns:
``True`` if the input is a tensor, ``False`` if it is a list of tensors.
Raises:
ValueError: If the input is a list of tensors that have a different number
of columns.
ValueError: If the input is a list of tensors that have incompatible shapes
with the parameters.
ValueError: If the input is a tensor and has the wrong shape.
ValueError: If the input is a tensor and its shape is incompatible with the
KFAC approximation's shape.
"""
if isinstance(M_torch, list):
return_tensor = False
if len(M_torch) != len(self._params):
raise ValueError(
"Number of input tensors must match the number of parameter tensors."
)
column_values = {len(M) for M in M_torch}
if len(column_values) != 1:
raise ValueError(
"Number of columns must be equal for all tensors. "
f"Got {column_values}."
)
K = column_values.pop()
for M, p in zip(M_torch, self._params):
if M.shape != (K,) + p.shape:
raise ValueError(
"All input tensors must have (K, ) + p.shape. "
f"Got {M.shape}, but expected {(K,) + p.shape}."
)
else:
return_tensor = True
if M_torch.ndim != 2:
raise ValueError(f"expected 2-d tensor, not {M_torch.ndim}-d")
if M_torch.shape[0] != self.shape[1]:
raise ValueError(f"dimension mismatch: {self.shape}, {M_torch.shape}")
M_torch = self._torch_preprocess(M_torch)
return return_tensor, M_torch

def torch_matmat(
self, M_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
"""Apply KFAC to a matrix (multiple vectors) in PyTorch.
This allows for matrix-matrix products with the KFAC approximation in PyTorch
without converting tensors to numpy arrays, which avoids unnecessary
device transfers when working with GPUs and flattening/concatenating.
Args:
M_torch: Matrix for multiplication. If list of tensors, each entry has the
same shape as a parameter with an additional leading dimension of size
``K`` for the columns, i.e. ``[(K,) + p1.shape, (K,) + p2.shape, ...]``.
If tensor, has shape ``[D, K]`` with some ``K``.
Returns:
Matrix-multiplication result ``KFAC @ M``. Return type is the same as the
type of the input. If list of tensors, each entry has the same shape as a
parameter with an additional leading dimension of size ``K`` for the columns,
i.e. ``[(K,) + p1.shape, (K,) + p2.shape, ...]``. If tensor, has shape
``[D, K]`` with some ``K``.
"""
return_tensor, M_torch = self._check_input_type_and_preprocess(M_torch)
if not self._input_covariances and not self._gradient_covariances:
self._compute_kfac()

M_torch = super()._preprocess(M)

for mod_name, param_pos in self._mapping.items():
# bias and weights are treated jointly
if (
Expand Down Expand Up @@ -295,6 +378,55 @@ def _matmat(self, M: ndarray) -> ndarray:
"j k,v k ... -> v j ...",
)

if return_tensor:
M_torch = cat([rearrange(M, "k ... -> (...) k") for M in M_torch])

return M_torch

def torch_matvec(
    self, v_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
    """Apply KFAC to a vector in PyTorch.

    This allows for matrix-vector products with the KFAC approximation in PyTorch
    without converting tensors to numpy arrays, which avoids unnecessary
    device transfers when working with GPUs and flattening/concatenating.

    Args:
        v_torch: Vector for multiplication. If list of tensors, each entry has the
            same shape as a parameter, i.e. ``[p1.shape, p2.shape, ...]``.
            If tensor, has shape ``[D]``.

    Returns:
        Matrix-multiplication result ``KFAC @ v``. Return type is the same as the
        type of the input. If list of tensors, each entry has the same shape as a
        parameter, i.e. ``[p1.shape, p2.shape, ...]``. If tensor, has shape ``[D]``.

    Raises:
        ValueError: If the input is neither a tensor nor a list of tensors.
        ValueError: If the input is a tensor that is not 1-d.
    """
    if isinstance(v_torch, list):
        # add a leading column axis of size 1 so the matrix code path can be used
        columns = [entry.unsqueeze(0) for entry in v_torch]
        return [res.squeeze(0) for res in self.torch_matmat(columns)]

    if isinstance(v_torch, Tensor):
        # reject non-vectors here; otherwise the failure surfaces later as a
        # confusing complaint about the unsqueezed tensor's dimensionality
        if v_torch.ndim != 1:
            raise ValueError(f"expected 1-d tensor, not {v_torch.ndim}-d")
        return self.torch_matmat(v_torch.unsqueeze(-1)).squeeze(-1)

    raise ValueError(
        f"Invalid input type: {type(v_torch)}. Expected list of tensors or tensor."
    )

def _matmat(self, M: ndarray) -> ndarray:
    """Multiply KFAC onto a matrix (multiple vectors).

    Args:
        M: Matrix for multiplication. Has shape ``[D, K]`` with some ``K``.

    Returns:
        Matrix-multiplication result ``KFAC @ M``. Has shape ``[D, K]``.
    """
    # pre-process with the parent class, multiply via ``torch_matmat``,
    # then post-process the result
    return self._postprocess(self.torch_matmat(super()._preprocess(M)))

def _adjoint(self) -> KFACLinearOperator:
Expand All @@ -308,7 +440,7 @@ def _adjoint(self) -> KFACLinearOperator:
return self

def _compute_kfac(self):
"""Compute and cache KFAC's Kronecker factors for future ``matvec``s."""
"""Compute and cache KFAC's Kronecker factors for future ``matmat``s."""
# install forward and backward hooks
hook_handles: List[RemovableHandle] = []

Expand Down
90 changes: 89 additions & 1 deletion test/test_inverse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Contains tests for ``curvlinops/inverse``."""

import torch
from einops import rearrange
from numpy import array, eye, random
from numpy.linalg import eigh, inv
from pytest import mark, raises
Expand Down Expand Up @@ -156,7 +157,6 @@ def test_KFAC_inverse_damped_matmat(
loss_average=loss_average,
separate_weight_and_bias=separate_weight_and_bias,
)
KFAC._compute_kfac()

# add damping manually
for aaT in KFAC._input_covariances.values():
Expand Down Expand Up @@ -185,3 +185,91 @@ def test_KFAC_inverse_damped_matmat(
# test that the cache is empty
assert len(inv_KFAC._inverse_input_covariances) == 0
assert len(inv_KFAC._inverse_gradient_covariances) == 0


def test_KFAC_inverse_damped_torch_matmat(case, delta: float = 1e-2):
    """Check ``torch_matmat`` of the damped inverse KFAC approximation.

    The tensor-input result is compared against the list-input result, against
    multiplication with the dense matrix obtained by applying the operator to the
    identity, and against the numpy-based ``@`` interface.
    """
    model_func, loss_func, params, data = case

    KFAC = KFACLinearOperator(
        model_func,
        loss_func,
        params,
        data,
        loss_average="batch" if loss_func.reduction == "mean" else None,
    )
    inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(delta, delta))
    # KFAC.dtype is a numpy data type, so take the torch dtype from the model
    tensor_kwargs = {
        "dtype": next(KFAC._model_func.parameters()).dtype,
        "device": KFAC._device,
    }

    num_vectors = 2
    X = torch.rand(KFAC.shape[1], num_vectors, **tensor_kwargs)
    result = inv_KFAC.torch_matmat(X)
    assert result.dtype == X.dtype
    assert result.device == X.device
    assert result.shape == (KFAC.shape[0], num_vectors)
    result_np = result.cpu().numpy()

    # list input format must give the same result as tensor input
    list_result = inv_KFAC.torch_matmat(KFAC._torch_preprocess(X))
    flat_blocks = [rearrange(block, "k ... -> (...) k") for block in list_result]
    report_nonclose(result_np, torch.cat(flat_blocks).cpu().numpy())

    # multiplication with the dense matrix representation
    dense = inv_KFAC.torch_matmat(torch.eye(inv_KFAC.shape[1], **tensor_kwargs))
    report_nonclose(result_np, (dense @ X).cpu().numpy(), rtol=5e-4)

    # numpy interface (_matmat)
    report_nonclose(result_np, inv_KFAC @ X.cpu().numpy())


def test_KFAC_inverse_damped_torch_matvec(case, delta: float = 1e-2):
    """Test torch matrix-vector multiplication by an inverse damped KFAC approximation.

    The tensor-input result is compared against the list-input result, against
    multiplication with the dense matrix obtained by applying the operator to the
    identity, and against the numpy-based ``@`` interface. All comparisons are
    done on numpy arrays moved to the CPU so the test also works when the
    operator lives on a GPU.
    """
    model_func, loss_func, params, data = case

    loss_average = "batch" if loss_func.reduction == "mean" else None
    KFAC = KFACLinearOperator(
        model_func,
        loss_func,
        params,
        data,
        loss_average=loss_average,
    )
    inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(delta, delta))
    device = KFAC._device
    # KFAC.dtype is a numpy data type
    dtype = next(KFAC._model_func.parameters()).dtype

    x = torch.rand(KFAC.shape[1], dtype=dtype, device=device)
    inv_KFAC_x = inv_KFAC.torch_matvec(x)
    assert inv_KFAC_x.dtype == x.dtype
    assert inv_KFAC_x.device == x.device
    assert inv_KFAC_x.shape == x.shape
    # convert once; numpy-based comparisons cannot consume GPU tensors
    inv_KFAC_x = inv_KFAC_x.cpu().numpy()

    # Test list input format
    # split parameter blocks
    dims = [p.numel() for p in KFAC._params]
    split_x = x.split(dims)
    # unflatten parameter dimension
    assert len(split_x) == len(KFAC._params)
    x_list = [res.reshape(p.shape) for res, p in zip(split_x, KFAC._params)]
    inv_kfac_x_list = inv_KFAC.torch_matvec(x_list)
    inv_kfac_x_list = torch.cat([rearrange(M, "... -> (...)") for M in inv_kfac_x_list])
    report_nonclose(inv_KFAC_x, inv_kfac_x_list.cpu().numpy())

    # Test against multiplication with dense matrix
    identity = torch.eye(inv_KFAC.shape[1], dtype=dtype, device=device)
    inv_KFAC_mat = inv_KFAC.torch_matmat(identity)
    inv_KFAC_mat_x = inv_KFAC_mat @ x
    report_nonclose(inv_KFAC_x, inv_KFAC_mat_x.cpu().numpy(), rtol=5e-5)

    # Test against _matmat (the numpy interface needs a numpy array on the CPU)
    report_nonclose(inv_KFAC @ x.cpu().numpy(), inv_KFAC_x)
Loading

0 comments on commit 67e043f

Please sign in to comment.