diff --git a/curvlinops/_torch_base.py b/curvlinops/_torch_base.py index dbb27c3..5819228 100644 --- a/curvlinops/_torch_base.py +++ b/curvlinops/_torch_base.py @@ -647,7 +647,8 @@ def _check_deterministic(self): - Two independent applications of matvec onto the same vector yield different results - - Two independent loss/gradient computations yield different results + - Two independent total loss/gradient computations yield different results + - Any mini-batch quantity differs while ``FIXED_DATA_ORDER`` is ``True`` Note: Deterministic checks should be performed on CPU. We noticed that even when @@ -712,15 +713,28 @@ def _check_deterministic(self): def __check_deterministic_batch( self, - Xs, - ys, - predictions, - losses, - gradients, + Xs: Tuple[Union[Tensor, MutableMapping], Union[Tensor, MutableMapping]], + ys: Tuple[Tensor, Tensor], + predictions: Tuple[Tensor, Tensor], + losses: Tuple[Optional[Tensor], Optional[Tensor]], + gradients: Tuple[Optional[List[Tensor]], Optional[List[Tensor]]], rtol: float = 1e-5, atol: float = 1e-8, ): - """Check that the data loader always returns the same order of data.""" + """Compare two outputs of ``self.data_prediction_loss_gradient``. + + Args: + Xs: The two data inputs to compare. + ys: The two data targets to compare. + predictions: The two predictions to compare. + losses: The two losses to compare. + gradients: The two gradients to compare. + rtol: Relative tolerance for comparison. Defaults to ``1e-5``. + atol: Absolute tolerance for comparison. Defaults to ``1e-8``. + + Raises: + RuntimeError: If any of the pairs mismatch. + """ X1, X2 = Xs if isinstance(X1, MutableMapping) and isinstance(X2, MutableMapping): @@ -754,6 +768,14 @@ def __check_deterministic_batch( raise RuntimeError("Check for deterministic batch gradient failed.") def __check_deterministic_matvec(self, rtol: float = 1e-5, atol: float = 1e-8): + """Probe whether the linear operator's matrix-vector product is deterministic. 
+ + Performs two sequential matrix-vector products and compares them. + + Args: + rtol: Relative tolerance for comparison. Defaults to ``1e-5``. + atol: Absolute tolerance for comparison. Defaults to ``1e-8``. + """ v = rand(self.shape[1], device=self._device, dtype=self._infer_dtype()) Av1 = self @ v Av2 = self @ v