diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 7bea6ff..a0d3745 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -262,12 +262,12 @@ def _compute_kfac(self): def _compute_loss_and_backward(self, output: Tensor, y: Tensor): """Compute the loss and the backward pass(es) required for KFAC. - + Args: output: The model's prediction :math:`\{f_\mathbf{\theta}(\mathbf{x}_n)\}_{n=1}^N`. y: The labels :math:`\{\mathbf{y}_n\}_{n=1}^N`. - + Raises: NotImplementedError: If ``fisher_type == 'type-2'`` and ``isinstance(self._loss_func, CrossEntropyLoss)``.