From 2de2ba98b2a3dc17f5ef6b2b85212e6a87a2d169 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 7 Nov 2023 15:16:32 +0100 Subject: [PATCH] Minor docstring fixes --- curvlinops/kfac.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 3ae02f2..329375c 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -117,13 +117,14 @@ def __init__( from the model's predictive distribution. Defaults to ``2147483647``. fisher_type: The type of Fisher/GGN to approximate. If ``'type-2'``, the expectation over the model outputs is computed exactly by - computing the backwward pass for each output dimension. This is + computing the backward pass for each output dimension. This is sometimes also called type-2 Fisher. If ``'mc'``, the expectation is approximated by sampling ``mc_samples`` labels from the model's predictive distribution. If ``'empirical'``, the empirical gradients are used which corresponds to the uncentered gradient covariance, or the empirical Fisher. Defaults to ``'mc'``. mc_samples: The number of Monte-Carlo samples to use per data point. + Will be ignored when ``fisher_type`` is not ``'mc'``. Defaults to ``1``. Raises: @@ -226,7 +227,6 @@ def _compute_kfac(self): NotImplementedError: If ``fisher_type == 'type-2'``. ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or ``'empirical'``. - """ # install forward and backward hooks hook_handles: List[RemovableHandle] = []