From e6ee5d6770ba5aa5b648a975facfb450210da838 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 7 Nov 2023 18:03:20 +0100 Subject: [PATCH] Clarify fisher_type docstring --- curvlinops/kfac.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 329375c..7fba18e 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -115,13 +115,14 @@ def __init__( from the parameters. Defaults to ``None``. seed: The seed for the random number generator used to draw labels from the model's predictive distribution. Defaults to ``2147483647``. - fisher_type: The type of Fisher/GGN to approximate. If ``'type-2'``, the - expectation over the model outputs is computed exactly by - computing the backward pass for each output dimension. This is - sometimes also called type-2 Fisher. If ``'mc'``, the expectation - is approximated by sampling ``mc_samples`` labels from the model's - predictive distribution. If ``'empirical'``, the empirical gradients - are used which corresponds to the uncentered gradient covariance, or + fisher_type: The type of Fisher/GGN to approximate. If ``'type-2'``, the + exact Hessian of the loss w.r.t. the model outputs is used. This + requires as many backward passes as the output dimension, i.e. + the number of classes for classification. This is sometimes also + called type-2 Fisher. If ``'mc'``, the expectation is approximated + by sampling ``mc_samples`` labels from the model's predictive + distribution. If ``'empirical'``, the empirical gradients are + used which corresponds to the uncentered gradient covariance, or the empirical Fisher. Defaults to ``'mc'``. mc_samples: The number of Monte-Carlo samples to use per data point. Will be ignored when ``fisher_type`` is not ``'mc'``.