diff --git a/curvlinops/submatrix.py b/curvlinops/submatrix.py index 7f3e260..3fda4a8 100644 --- a/curvlinops/submatrix.py +++ b/curvlinops/submatrix.py @@ -1,5 +1,7 @@ """Implements slices of linear operators.""" +from __future__ import annotations + from typing import List from numpy import column_stack, ndarray, zeros @@ -78,3 +80,13 @@ def _matmat(self, X: ndarray) -> ndarray: ``A[row_idxs, :][:, col_idxs] @ x``. Has shape ``[len(row_idxs), N]``. """ return column_stack([self @ col for col in X.T]) + + def _adjoint(self) -> SubmatrixLinearOperator: + """Return the adjoint of the sub-matrix. + + For that, we need to take the adjoint operator, and swap row and column indices. + + Returns: + The linear operator for the adjoint sub-matrix. + """ + return type(self)(self._A.adjoint(), self._col_idxs, self._row_idxs) diff --git a/setup.cfg b/setup.cfg index 3527794..75924ff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,6 +75,7 @@ lint = # Dependencies needed to build/view the documentation (semicolon/line-separated) docs = + setuptools==69.5.1 # RTD fails with setuptools>=70, see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/15863 transformers datasets matplotlib diff --git a/test/test_submatrix.py b/test/test_submatrix.py index 216efc3..7532905 100644 --- a/test/test_submatrix.py +++ b/test/test_submatrix.py @@ -3,7 +3,7 @@ from typing import List, Tuple from numpy import eye, ndarray, random -from pytest import fixture, raises +from pytest import fixture, mark, raises from scipy.sparse.linalg import aslinearoperator from curvlinops.examples.utils import report_nonclose @@ -34,29 +34,60 @@ def submatrix_case(request) -> Tuple[ndarray, List[int], List[int]]: return case["A_fn"](), case["row_idxs_fn"](), case["col_idxs_fn"]() -def test_SubmatrixLinearOperator__matvec(submatrix_case): +@mark.parametrize("adjoint", [False, True], ids=["", "adjoint"]) +def test_SubmatrixLinearOperator__matvec( + submatrix_case: Tuple[ndarray, List[int], List[int]], adjoint: bool +): + """Test the matrix-vector multiplication of a submatrix linear operator. + + Args: + submatrix_case: A tuple with a random matrix and two index lists. + adjoint: Whether to take the operator's adjoint before multiplying. + """ A, row_idxs, col_idxs = submatrix_case A_sub = A[row_idxs, :][:, col_idxs] A_sub_linop = SubmatrixLinearOperator(aslinearoperator(A), row_idxs, col_idxs) - x = random.rand(len(col_idxs)) + if adjoint: + A_sub = A_sub.conj().T + A_sub_linop = A_sub_linop.adjoint() + + x = random.rand(A_sub.shape[1]) A_sub_linop_x = A_sub_linop @ x - assert A_sub_linop_x.shape == (len(row_idxs),) + assert A_sub_linop_x.shape == ((len(col_idxs),) if adjoint else (len(row_idxs),)) report_nonclose(A_sub @ x, A_sub_linop_x) -def test_SubmatrixLinearOperator__matmat(submatrix_case, num_vecs: int = 3): +@mark.parametrize("adjoint", [False, True], ids=["", "adjoint"]) +def test_SubmatrixLinearOperator__matmat( + submatrix_case: Tuple[ndarray, List[int], List[int]], + adjoint: bool, + num_vecs: int = 3, +): + """Test the matrix-matrix multiplication of a submatrix linear operator. + + Args: + submatrix_case: A tuple with a random matrix and two index lists. + adjoint: Whether to take the operator's adjoint before multiplying. + num_vecs: The number of vectors to multiply. Default: ``3``. + """ A, row_idxs, col_idxs = submatrix_case A_sub = A[row_idxs, :][:, col_idxs] A_sub_linop = SubmatrixLinearOperator(aslinearoperator(A), row_idxs, col_idxs) - X = random.rand(len(col_idxs), num_vecs) + if adjoint: + A_sub = A_sub.conj().T + A_sub_linop = A_sub_linop.adjoint() + + X = random.rand(A_sub.shape[1], num_vecs) A_sub_linop_X = A_sub_linop @ X - assert A_sub_linop_X.shape == (len(row_idxs), num_vecs) + assert A_sub_linop_X.shape == ( + (len(col_idxs), num_vecs) if adjoint else (len(row_idxs), num_vecs) + ) report_nonclose(A_sub @ X, A_sub_linop_X)