Skip to content

Commit

Permalink
Clarify fisher_type docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Nov 7, 2023
1 parent 2de2ba9 commit e6ee5d6
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,14 @@ def __init__(
from the parameters. Defaults to ``None``.
seed: The seed for the random number generator used to draw labels
from the model's predictive distribution. Defaults to ``2147483647``.
fisher_type: The type of Fisher/GGN to approximate. If ``'type-2'``, the
expectation over the model outputs is computed exactly by
computing the backward pass for each output dimension. This is
sometimes also called type-2 Fisher. If ``'mc'``, the expectation
is approximated by sampling ``mc_samples`` labels from the model's
predictive distribution. If ``'empirical'``, the empirical gradients
are used which corresponds to the uncentered gradient covariance, or
fisher_type: The type of Fisher/GGN to approximate. If 'type-2', the
exact Hessian of the loss w.r.t. the model outputs is used. This
requires as many backward passes as the output dimension, i.e.
the number of classes for classification. This is sometimes also
called type-2 Fisher. If ``'mc'``, the expectation is approximated
by sampling ``mc_samples`` labels from the model's predictive
distribution. If ``'empirical'``, the empirical gradients are
used which corresponds to the uncentered gradient covariance, or
the empirical Fisher. Defaults to ``'mc'``.
mc_samples: The number of Monte-Carlo samples to use per data point.
Will be ignored when ``fisher_type`` is not ``'mc'``.
Expand Down

0 comments on commit e6ee5d6

Please sign in to comment.