diff --git a/curvlinops/_torch_base.py b/curvlinops/_torch_base.py index 4937bfc..97c4aaf 100644 --- a/curvlinops/_torch_base.py +++ b/curvlinops/_torch_base.py @@ -32,8 +32,14 @@ class PyTorchLinearOperator: operator, which can be useful for interfacing with SciPy routines. To achieve this, the functions ``_infer_device`` and ``_infer_dtype`` must be implemented. + Attributes: + SELF_ADJOINT: Whether the linear operator is self-adjoint. If ``True``, + ``_adjoint`` does not need to be implemented. Default: ``False``. + """ + SELF_ADJOINT: bool = False + def __init__( self, in_shape: List[Tuple[int, ...]], out_shape: List[Tuple[int, ...]] ): @@ -106,7 +112,7 @@ def adjoint(self) -> PyTorchLinearOperator: Returns: The adjoint of the linear operator. """ - return self._adjoint() + return self if self.SELF_ADJOINT else self._adjoint() def _adjoint(self) -> PyTorchLinearOperator: """Adjoint of the linear operator. diff --git a/curvlinops/hessian.py b/curvlinops/hessian.py index e64abdd..2bae30b 100644 --- a/curvlinops/hessian.py +++ b/curvlinops/hessian.py @@ -1,7 +1,5 @@ """Contains a linear operator implementation of the Hessian.""" -from __future__ import annotations - from collections.abc import MutableMapping from typing import List, Union @@ -37,8 +35,11 @@ class HessianLinearOperator(CurvatureLinearOperator): Attributes: SUPPORTS_BLOCKS: Whether the linear operator supports block operations. Default is ``True``. + SELF_ADJOINT: Whether the linear operator is self-adjoint (``True`` for + Hessians). """ + SELF_ADJOINT: bool = True SUPPORTS_BLOCKS: bool = True def _matmat_batch( @@ -81,13 +82,3 @@ def _matmat_batch( AM_block[p][..., n].add_(col) return AM - - def _adjoint(self) -> HessianLinearOperator: - """Return the linear operator representing the adjoint. - - The Hessian is real symmetric, and hence self-adjoint. - - Returns: - Self. - """ - return self