Skip to content

Commit

Permalink
[ADD] Automatically implement adjoint for self-adjoint ops
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Sep 22, 2024
1 parent 07144a3 commit 9c538ea
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
8 changes: 7 additions & 1 deletion curvlinops/_torch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,14 @@ class PyTorchLinearOperator:
operator, which can be useful for interfacing with SciPy routines. To achieve this,
the functions ``_infer_device`` and ``_infer_dtype`` must be implemented.
Attributes:
SELF_ADJOINT: Whether the linear operator is self-adjoint. If ``True``,
``_adjoint`` does not need to be implemented. Default: ``False``.
"""

SELF_ADJOINT: bool = False

def __init__(
self, in_shape: List[Tuple[int, ...]], out_shape: List[Tuple[int, ...]]
):
Expand Down Expand Up @@ -106,7 +112,7 @@ def adjoint(self) -> PyTorchLinearOperator:
Returns:
The adjoint of the linear operator.
"""
return self._adjoint()
return self if self.SELF_ADJOINT else self._adjoint()

def _adjoint(self) -> PyTorchLinearOperator:
"""Adjoint of the linear operator.
Expand Down
15 changes: 3 additions & 12 deletions curvlinops/hessian.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Contains a linear operator implementation of the Hessian."""

from __future__ import annotations

from collections.abc import MutableMapping
from typing import List, Union

Expand Down Expand Up @@ -37,8 +35,11 @@ class HessianLinearOperator(CurvatureLinearOperator):
Attributes:
SUPPORTS_BLOCKS: Whether the linear operator supports block operations.
Default is ``True``.
SELF_ADJOINT: Whether the linear operator is self-adjoint (``True`` for
Hessians).
"""

SELF_ADJOINT: bool = True
SUPPORTS_BLOCKS: bool = True

def _matmat_batch(
Expand Down Expand Up @@ -81,13 +82,3 @@ def _matmat_batch(
AM_block[p][..., n].add_(col)

return AM

def _adjoint(self) -> HessianLinearOperator:
"""Return the linear operator representing the adjoint.
The Hessian is real symmetric, and hence self-adjoint.
Returns:
Self.
"""
return self

0 comments on commit 9c538ea

Please sign in to comment.