Skip to content

Commit

Permalink
[DOC] Add documentation and type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Sep 23, 2024
1 parent 4f546fc commit 87c6eba
Showing 1 changed file with 29 additions and 7 deletions.
36 changes: 29 additions & 7 deletions curvlinops/_torch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ def _check_deterministic(self):
- Two independent applications of matvec onto the same vector yield different
results
- Two independent loss/gradient computations yield different results
- Two independent total loss/gradient computations yield different results
- If ``FIXED_DATA_ORDER`` is ``True`` and any mini-batch quantity differs.
Note:
Deterministic checks should be performed on CPU. We noticed that even when
Expand Down Expand Up @@ -712,15 +713,28 @@ def _check_deterministic(self):

def __check_deterministic_batch(
self,
Xs,
ys,
predictions,
losses,
gradients,
Xs: Tuple[Union[Tensor, MutableMapping], Union[Tensor, MutableMapping]],
ys: Tuple[Tensor, Tensor],
predictions: Tuple[Tensor, Tensor],
losses: Tuple[Optional[Tensor], Optional[Tensor]],
gradients: Tuple[Optional[List[Tensor]], Optional[List[Tensor]]],
rtol: float = 1e-5,
atol: float = 1e-8,
):
"""Check that the data loader always returns the same order of data."""
"""Compare two outputs of ``self.data_prediction_loss_gradient``.
Args:
Xs: The two data inputs to compare.
ys: The two data targets to compare.
predictions: The two predictions to compare.
losses: The two losses to compare.
gradients: The two gradients to compare.
rtol: Relative tolerance for comparison. Default is 1e-5.
atol: Absolute tolerance for comparison. Default is 1e-8.
Raises:
RuntimeError: If any of the pairs mismatch.
"""

X1, X2 = Xs
if isinstance(X1, MutableMapping) and isinstance(X2, MutableMapping):
Expand Down Expand Up @@ -754,6 +768,14 @@ def __check_deterministic_batch(
raise RuntimeError("Check for deterministic batch gradient failed.")

def __check_deterministic_matvec(self, rtol: float = 1e-5, atol: float = 1e-8):
"""Probe whether the linear operator's matrix-vector product is deterministic.
Performs two sequential matrix-vector products and compares them.
Args:
rtol: Relative tolerance for comparison. Defaults to ``1e-5``.
atol: Absolute tolerance for comparison. Defaults to ``1e-8``.
"""
v = rand(self.shape[1], device=self._device, dtype=self._infer_dtype())
Av1 = self @ v
Av2 = self @ v
Expand Down

0 comments on commit 87c6eba

Please sign in to comment.