Skip to content

Commit

Permalink
[REF] Implement inverses matmat instead of matvec, remove base class
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Feb 5, 2024
1 parent 8bae5de commit ae0d25e
Showing 1 changed file with 23 additions and 33 deletions.
56 changes: 23 additions & 33 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,15 @@
from scipy.sparse.linalg import LinearOperator, cg


class _InverseLinearOperator(LinearOperator):
"""Base class for (approximate) inverses of linear operators."""

def _matmat(self, X: ndarray) -> ndarray:
"""Matrix-matrix multiplication.
Args:
X: Matrix for multiplication.
Returns:
Matrix-multiplication result ``A⁻¹@ X``.
"""
return column_stack([self @ col for col in X.T])


class CGInverseLinearOperator(_InverseLinearOperator):
class CGInverseLinearOperator(LinearOperator):
"""Class for inverse linear operators via conjugate gradients."""

def __init__(self, A: LinearOperator):
"""Store the linear operator whose inverse should be represented.
Args:
A: Linear operator whose inverse is formed. Must be symmetric and
positive-definite
positive-definite.
"""
super().__init__(A.dtype, A.shape)
self._A = A
Expand Down Expand Up @@ -57,20 +41,24 @@ def set_cg_hyperparameters(
"atol": atol,
}

def _matvec(self, x: ndarray) -> ndarray:
"""Multiply x by the inverse of A.
def _matmat(self, M: ndarray) -> ndarray:
"""Multiply matrix ``M`` by the inverse of ``A``.
Args:
x: Vector for multiplication.
M: Matrix for multiplication.
Returns:
Result of inverse matrix-vector multiplication, ``A⁻¹ @ x``.
Result of inverse matrix-vector multiplication, ``A⁻¹ @ M``.
"""
result, _ = cg(self._A, x, **self._cg_hyperparameters)
return result
results = []
for col in M.T:
result_col, _ = cg(self._A, col, **self._cg_hyperparameters)
results.append(result_col)

return column_stack(results)

class NeumannInverseLinearOperator(_InverseLinearOperator):

class NeumannInverseLinearOperator(LinearOperator):
"""Class for inverse linear operators via truncated Neumann series.
# noqa: B950
Expand Down Expand Up @@ -171,23 +159,23 @@ def set_neumann_hyperparameters(
self._scale = scale
self._check_nan = check_nan

def _matvec(self, x: ndarray) -> ndarray:
"""Multiply x by the inverse of A.
def _matmat(self, M: ndarray) -> ndarray:
"""Multiply a matrix ``M`` by the inverse of ``A``.
Args:
x: Vector for multiplication.
M: Matrix for multiplication.
Returns:
Result of inverse matrix-vector multiplication, ``A⁻¹ @ x``.
Result of inverse matrix-matrix multiplication, ``A⁻¹ @ M``.
Raises:
ValueError: If ``NaN`` check is turned on and ``NaN``s are detected.
"""
result, v = x.copy(), x.copy()
result, V = M.copy(), M.copy()

for idx in range(self._num_terms):
v = v - self._scale * (self._A @ v)
result = result + v
M -= self._scale * (self._A @ V)
result += V

if self._check_nan and not allclose(result, result):
raise ValueError(
Expand All @@ -196,4 +184,6 @@ def _matvec(self, x: ndarray) -> ndarray:
+ " Try decreasing `scale` and read the comment on convergence."
)

return self._scale * result
result *= self._scale

return result

0 comments on commit ae0d25e

Please sign in to comment.