diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py index cd5482b..1ccacb4 100644 --- a/curvlinops/inverse.py +++ b/curvlinops/inverse.py @@ -1,20 +1,16 @@ """Implements linear operator inverses.""" from math import sqrt -from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union from warnings import warn -from einops import einsum, rearrange +from einops import rearrange from numpy import allclose, column_stack, ndarray from scipy.sparse.linalg import LinearOperator, cg, lsmr from torch import Tensor, cat, cholesky_inverse, eye, float64, outer from torch.linalg import cholesky, eigh -from curvlinops.kfac import KFACLinearOperator, ParameterMatrixType - -KFACInvType = TypeVar( - "KFACInvType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]] -) +from curvlinops.kfac import FactorType, KFACLinearOperator, ParameterMatrixType class _InverseLinearOperator(LinearOperator): @@ -355,6 +351,8 @@ def __init__( raise ValueError( "Heuristic and exact damping require a single damping value." ) + if self._A._correct_eigenvalues and not use_exact_damping: + raise ValueError("Only exact damping is supported for EKFAC.") self._damping = damping self._use_heuristic_damping = use_heuristic_damping @@ -362,8 +360,8 @@ def __init__( self._use_exact_damping = use_exact_damping self._cache = cache self._retry_double_precision = retry_double_precision - self._inverse_input_covariances: Dict[str, KFACInvType] = {} - self._inverse_gradient_covariances: Dict[str, KFACInvType] = {} + self._inverse_input_covariances: Dict[str, FactorType] = {} + self._inverse_gradient_covariances: Dict[str, FactorType] = {} def _compute_damping( self, aaT: Optional[Tensor], ggT: Optional[Tensor] @@ -407,19 +405,54 @@ def _damped_cholesky(self, M: Tensor, damping: float) -> Tensor: M.add(eye(M.shape[0], dtype=M.dtype, device=M.device), alpha=damping) ) + def _compute_inv_damped_eigenvalues( + self, aaT_eigvals: Tensor, ggT_eigvals: Tensor, name: str + ) -> Union[Tensor, Dict[str, Tensor]]: + """Compute the inverses of the damped eigenvalues for a given layer. + + Args: + aaT_eigvals: Eigenvalues of the input covariance matrix. + ggT_eigvals: Eigenvalues of the gradient covariance matrix. + name: Name of the layer for which to damp and invert eigenvalues. + + Returns: + Inverses of the damped eigenvalues. + """ + param_pos = self._A._mapping[name] + if ( + not self._A._separate_weight_and_bias + and "weight" in param_pos + and "bias" in param_pos + ): + inv_damped_eigenvalues = ( + outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) + ) + else: + inv_damped_eigenvalues = {} + for p_name, pos in param_pos.items(): + inv_damped_eigenvalues[pos] = ( + outer(ggT_eigvals, aaT_eigvals) + if p_name == "weight" + else ggT_eigvals + ) + inv_damped_eigenvalues[pos].add_(self._damping).pow_(-1) + return inv_damped_eigenvalues + def _compute_inverse_factors( - self, aaT: Optional[Tensor], ggT: Optional[Tensor] - ) -> Tuple[KFACInvType, KFACInvType]: + self, aaT: Optional[Tensor], ggT: Optional[Tensor], name: str + ) -> Tuple[FactorType, FactorType, Optional[Tensor]]: """Compute the inverses of the Kronecker factors for a given layer. Args: aaT: Input covariance matrix. ``None`` for biases. ggT: Gradient covariance matrix. + name: Name of the layer for which to invert Kronecker factors. Returns: Tuple of inverses (or eigendecompositions) of the input and gradient - covariance Kronecker factors. 
Can be ``None`` if the input or gradient - covariance is ``None`` (e.g. the input covariances for biases). + covariance Kronecker factors and optionally eigenvalues. Can be ``None`` if + the input or gradient covariance is ``None`` (e.g. the input covariances for + biases). Raises: RuntimeError: If a Cholesky decomposition (and optionally the retry in @@ -430,7 +463,10 @@ def _compute_inverse_factors( # Kronecker-factored eigenbasis (KFE). aaT_eigvals, aaT_eigvecs = (None, None) if aaT is None else eigh(aaT) ggT_eigvals, ggT_eigvecs = (None, None) if ggT is None else eigh(ggT) - return (aaT_eigvecs, aaT_eigvals), (ggT_eigvecs, ggT_eigvals) + inv_damped_eigenvalues = self._compute_inv_damped_eigenvalues( + aaT_eigvals, ggT_eigvals, name + ) + return aaT_eigvecs, ggT_eigvecs, inv_damped_eigenvalues else: damping_aaT, damping_ggT = self._compute_damping(aaT, ggT) @@ -476,11 +512,11 @@ def _compute_inverse_factors( raise error ggT_inv = None if ggT_chol is None else cholesky_inverse(ggT_chol) - return aaT_inv, ggT_inv + return aaT_inv, ggT_inv, None def _compute_or_get_cached_inverse( self, name: str - ) -> Tuple[KFACInvType, KFACInvType]: + ) -> Tuple[FactorType, FactorType, Optional[Tensor]]: """Invert the Kronecker factors of the KFACLinearOperator or retrieve them. Args: @@ -488,117 +524,38 @@ def _compute_or_get_cached_inverse( Returns: Tuple of inverses (or eigendecompositions) of the input and gradient - covariance Kronecker factors. Can be ``None`` if the input or gradient - covariance is ``None`` (e.g. the input covariances for biases). + covariance Kronecker factors and optionally eigenvalues. Can be ``None`` if + the input or gradient covariance is ``None`` (e.g. the input covariances for + biases). """ if name in self._inverse_input_covariances: aaT_inv = self._inverse_input_covariances.get(name) ggT_inv = self._inverse_gradient_covariances.get(name) - return aaT_inv, ggT_inv + return aaT_inv, ggT_inv, None + + if self._A._correct_eigenvalues: + aaT_eigenvecs = self._A._input_covariances_eigenvectors.get(name) + ggT_eigenvecs = self._A._gradient_covariances_eigenvectors.get(name) + eigenvalues = self._A._corrected_eigenvalues.get(name) + if isinstance(eigenvalues, dict): + inv_damped_eigenvalues = {} + for key, val in eigenvalues.items(): + inv_damped_eigenvalues[key] = val.add(self._damping).pow_(-1) + elif isinstance(eigenvalues, Tensor): + inv_damped_eigenvalues = eigenvalues.add(self._damping).pow_(-1) + return aaT_eigenvecs, ggT_eigenvecs, inv_damped_eigenvalues aaT = self._A._input_covariances.get(name) ggT = self._A._gradient_covariances.get(name) - aaT_inv, ggT_inv = self._compute_inverse_factors(aaT, ggT) + aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors( + aaT, ggT, name + ) if self._cache: self._inverse_input_covariances[name] = aaT_inv self._inverse_gradient_covariances[name] = ggT_inv - return aaT_inv, ggT_inv - - def _left_and_right_multiply( - self, M_joint: Tensor, aaT_inv: KFACInvType, ggT_inv: KFACInvType - ) -> Tensor: - """Left and right multiply matrix with inverse Kronecker factors. - - Args: - M_joint: Matrix for multiplication. - aaT_inv: Inverse of the input covariance Kronecker factor. ``None`` for - biases. - ggT_inv: Inverse of the gradient covariance Kronecker factor. - - Returns: - Matrix-multiplication result ``KFAC⁻¹ @ M_joint``. - """ - if self._use_exact_damping: - # Perform damped preconditioning in KFE, e.g. see equation (21) in - # https://arxiv.org/abs/2308.03296. 
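# --- Illustrative sketch (not part of this patch): exact damping in the
# Kronecker-factored eigenbasis (KFE), cf. equation (21) in
# https://arxiv.org/abs/2308.03296. It shows how (ggT ⊗ aaT + delta * I)^{-1} @ v
# can be applied using only the eigendecompositions of the two Kronecker
# factors, which is what the exact-damping branch above relies on. All names
# are local to this sketch.
import torch


def damped_kfac_inverse_matvec(
    aaT: torch.Tensor, ggT: torch.Tensor, v: torch.Tensor, delta: float
) -> torch.Tensor:
    """Apply ``(ggT ⊗ aaT + delta * I)^{-1}`` to a flat vector ``v``."""
    aaT_eigvals, aaT_eigvecs = torch.linalg.eigh(aaT)
    ggT_eigvals, ggT_eigvecs = torch.linalg.eigh(ggT)
    V = v.reshape(ggT.shape[0], aaT.shape[0])
    # rotate into the KFE, divide by the damped eigenvalues, rotate back
    V_kfe = ggT_eigvecs.T @ V @ aaT_eigvecs
    V_kfe = V_kfe / (torch.outer(ggT_eigvals, aaT_eigvals) + delta)
    return (ggT_eigvecs @ V_kfe @ aaT_eigvecs.T).reshape(-1)


# quick self-check against the explicitly damped dense Kronecker product
aaT = torch.rand(3, 3)
aaT = aaT @ aaT.T + torch.eye(3)
ggT = torch.rand(2, 2)
ggT = ggT @ ggT.T + torch.eye(2)
v, delta = torch.rand(6), 1e-2
dense = torch.kron(ggT, aaT) + delta * torch.eye(6)
assert torch.allclose(
    damped_kfac_inverse_matvec(aaT, ggT, v, delta),
    torch.linalg.solve(dense, v),
    atol=1e-5,
)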
- aaT_eigvecs, aaT_eigvals = aaT_inv - ggT_eigvecs, ggT_eigvals = ggT_inv - # Transform in eigenbasis. - M_joint = einsum( - ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m i k, k l -> m j l" - ) - # Divide by damped eigenvalues to perform the inversion. - M_joint.div_(outer(ggT_eigvals, aaT_eigvals).add_(self._damping)) - # Transform back to standard basis. - M_joint = einsum( - ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m j k, l k -> m i l" - ) - else: - M_joint = einsum(ggT_inv, M_joint, aaT_inv, "i j, m j k, k l -> m i l") - return M_joint - - def _separate_left_and_right_multiply( - self, - M_torch: Tensor, - param_pos: Dict[str, int], - aaT_inv: KFACInvType, - ggT_inv: KFACInvType, - ) -> Tensor: - """Multiply matrix with inverse Kronecker factors for separated weight and bias. - - Args: - M_torch: Matrix for multiplication. - param_pos: Dictionary with positions of the weight and bias parameters. - aaT_inv: Inverse of the input covariance Kronecker factor. ``None`` for - biases. - ggT_inv: Inverse of the gradient covariance Kronecker factor. - - Returns: - Matrix-multiplication result ``KFAC⁻¹ @ M_torch``. - """ - if self._use_exact_damping: - # Perform damped preconditioning in KFE, e.g. see equation (21) in - # https://arxiv.org/abs/2308.03296. - aaT_eigvecs, aaT_eigvals = aaT_inv - ggT_eigvecs, ggT_eigvals = ggT_inv - - for p_name, pos in param_pos.items(): - # for weights we need to multiply from the right with aaT - # for weights and biases we need to multiply from the left with ggT - if p_name == "weight": - M_w = rearrange(M_torch[pos], "m c_out ... -> m c_out (...)") - aaT_fac = aaT_eigvecs if self._use_exact_damping else aaT_inv - # If `use_exact_damping` is `True`, we transform to eigenbasis - M_torch[pos] = einsum(M_w, aaT_fac, "m i j, j k -> m i k") - - ggT_fac = ggT_eigvecs if self._use_exact_damping else ggT_inv - dims = ( - "m i ... -> m j ..." - if self._use_exact_damping - else " m j ... -> m i ..." - ) - # If `use_exact_damping` is `True`, we transform to eigenbasis - M_torch[pos] = einsum(ggT_fac, M_torch[pos], f"i j, {dims}") - - if self._use_exact_damping: - # Divide by damped eigenvalues to perform the inversion and transform - # back to standard basis. - if p_name == "weight": - M_torch[pos].div_( - outer(ggT_eigvals, aaT_eigvals).add_(self._damping) - ) - M_torch[pos] = einsum( - M_torch[pos], aaT_eigvecs, "m i j, k j -> m i k" - ) - else: - M_torch[pos].div_(ggT_eigvals.add_(self._damping)) - M_torch[pos] = einsum( - ggT_eigvecs, M_torch[pos], "i j, m j ... -> m i ..." - ) - - return M_torch + return aaT_inv, ggT_inv, inv_damped_eigenvalues def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: """Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch. @@ -621,12 +578,19 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: ``[D, K]`` with some ``K``. 
""" return_tensor, M_torch = self._A._check_input_type_and_preprocess(M_torch) - if not self._A._input_covariances and not self._A._gradient_covariances: + if ( + not self._A._input_covariances + and not self._A._gradient_covariances + and not self._A._input_covariances_eigenvectors + and not self._A._gradient_covariances_eigenvectors + ): self._A._compute_kfac() for mod_name, param_pos in self._A._mapping.items(): # retrieve the inverses of the Kronecker factors from cache or invert them - aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name) + aaT_inv, ggT_inv, inv_damped_eigenvalues = ( + self._compute_or_get_cached_inverse(mod_name) + ) # cache the weight shape to ensure correct shapes are returned if "weight" in param_pos: weight_shape = M_torch[param_pos["weight"]].shape @@ -640,12 +604,14 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: w_pos, b_pos = param_pos["weight"], param_pos["bias"] M_w = rearrange(M_torch[w_pos], "m c_out ... -> m c_out (...)") M_joint = cat([M_w, M_torch[b_pos].unsqueeze(2)], dim=2) - M_joint = self._left_and_right_multiply(M_joint, aaT_inv, ggT_inv) + M_joint = self._A._left_and_right_multiply( + M_joint, aaT_inv, ggT_inv, inv_damped_eigenvalues + ) w_cols = M_w.shape[2] M_torch[w_pos], M_torch[b_pos] = M_joint.split([w_cols, 1], dim=2) else: - M_torch = self._separate_left_and_right_multiply( - M_torch, param_pos, aaT_inv, ggT_inv + M_torch = self._A._separate_left_and_right_multiply( + M_torch, param_pos, aaT_inv, ggT_inv, inv_damped_eigenvalues ) # restore original shapes diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 413c86b..33d0d74 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -28,6 +28,7 @@ from numpy import ndarray from torch import Generator, Tensor, cat, device, eye, randn, stack from torch.autograd import grad +from torch.linalg import eigh from torch.nn import ( BCEWithLogitsLoss, Conv2d, @@ -50,12 +51,23 @@ # shape as the parameters, or a single matrix/vector of shape `[D, D]`/`[D]` where `D` # is the number of parameters. ParameterMatrixType = TypeVar("ParameterMatrixType", Tensor, List[Tensor]) +FactorType = TypeVar( + "FactorType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]] +) class MetaEnum(EnumMeta): """Metaclass for the Enum class for desired behavior of the `in` operator.""" - def __contains__(cls, item): + def __contains__(cls, item: str) -> bool: + """Check if an item is a valid member of the Enum. + + Args: + item: The item to check. + + Returns: + ``True`` if the item is a valid member of the Enum, ``False`` otherwise. + """ try: cls(item) except ValueError: @@ -144,6 +156,7 @@ def __init__( fisher_type: str = FisherType.MC, mc_samples: int = 1, kfac_approx: str = KFACType.EXPAND, + correct_eigenvalues: bool = False, num_per_example_loss_terms: Optional[int] = None, separate_weight_and_bias: bool = True, num_data: Optional[int] = None, @@ -204,6 +217,11 @@ def __init__( See `Eschenhagen et al., 2023 `_ for an explanation of the two approximations. Defaults to ``KFACType.EXPAND``. + correct_eigenvalues: Whether to correct the eigenvalues in the KFAC + eigenbasis, as proposed in + `George et al., 2018 `_. If true, + will only store the eigendecomposition of the KFAC approximation. + Defaults to ``False``. num_per_example_loss_terms: Number of per-example loss terms, e.g., the number of tokens in a sequence. 
The model outputs will have ``num_data * num_per_example_loss_terms * C`` entries, where ``C`` is @@ -225,7 +243,11 @@ def __init__( Raises: RuntimeError: If the check for deterministic behavior fails. ValueError: If the loss function is not supported. + ValueError: If the Fisher type is not supported. + ValueError: If the KFAC approximation type is not supported. ValueError: If ``fisher_type != FisherType.MC`` and ``mc_samples != 1``. + NotImplementedError: If ``correct_eigenvalues`` and ``fisher_type == + FisherType.FORWARD_ONLY``. ValueError: If ``X`` is not a tensor and ``batch_size_fn`` is not specified. """ if not isinstance(loss_func, self._SUPPORTED_LOSSES): @@ -242,6 +264,10 @@ def __init__( f"Invalid mc_samples: {mc_samples}. " "Only mc_samples=1 is supported for `fisher_type != FisherType.MC`." ) + if fisher_type == FisherType.FORWARD_ONLY and correct_eigenvalues: + raise NotImplementedError( + "Correcting eigenvalues is not supported for FisherType.FORWARD_ONLY." + ) if kfac_approx not in self._SUPPORTED_KFAC_APPROX: raise ValueError( f"Invalid kfac_approx: {kfac_approx}. " @@ -254,10 +280,25 @@ def __init__( self._fisher_type = fisher_type self._mc_samples = mc_samples self._kfac_approx = kfac_approx + self._correct_eigenvalues = correct_eigenvalues + # Initialize flag which determines whether to compute the KFAC factors or the + # eigenvalue correction in the forward-backward pass(es) + self._compute_eigenvalue_correction_flag = False self._input_covariances: Dict[str, Tensor] = {} self._gradient_covariances: Dict[str, Tensor] = {} self._mapping = self.compute_parameter_mapping(params, model_func) + # Initialize the eigenvectors and eigenvalues of the Kronecker factors + self._input_covariances_eigenvectors: Dict[str, Tensor] = {} + self._input_covariances_eigenvalues: Dict[str, Tensor] = {} + self._gradient_covariances_eigenvectors: Dict[str, Tensor] = {} + self._gradient_covariances_eigenvalues: Dict[str, Tensor] = {} + + # Initialize the cache for activations + self._cached_activations: Dict[str, Tensor] = {} + # Initialize the corrected eigenvalues for EKFAC + self._corrected_eigenvalues: Dict[str, Tensor] = {} + # Properties of the full matrix KFAC approximation are initialized to `None` self._reset_matrix_properties() @@ -408,6 +449,91 @@ def _check_input_type_and_preprocess( M_torch = self._torch_preprocess(M_torch) return return_tensor, M_torch + @staticmethod + def _left_and_right_multiply( + M_joint: Tensor, + aaT: FactorType, + ggT: FactorType, + eigenvalues: Optional[Tensor], + ) -> Tensor: + """Left and right multiply matrix with Kronecker factors. + + Args: + M_joint: Matrix for multiplication. + aaT: Input covariance Kronecker factor or its eigenvectors. ``None`` for + biases. + ggT: Gradient covariance Kronecker factor or its eigenvectors. + eigenvalues: Eigenvalues of the (E)KFAC approximation when multiplying with + the eigendecomposition of the KFAC approximation. ``None`` for the + non-decomposed KFAC approximation. + + Returns: + Matrix-multiplication result ``KFAC @ M_joint``. + """ + if eigenvalues is None: + M_joint = einsum(ggT, M_joint, aaT, "i j, m j k, k l -> m i l") + else: + # Perform preconditioning in KFE, e.g. see equation (21) in + # https://arxiv.org/abs/2308.03296. + aaT_eigvecs = aaT + ggT_eigvecs = ggT + # Transform in eigenbasis. + M_joint = einsum( + ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m i k, k l -> m j l" + ) + # Multiply (broadcasted) by eigenvalues. + M_joint.mul_(eigenvalues) + # Transform back to standard basis. 
+ M_joint = einsum( + ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m j k, l k -> m i l" + ) + return M_joint + + @staticmethod + def _separate_left_and_right_multiply( + M_torch: Tensor, + param_pos: Dict[str, int], + aaT: FactorType, + ggT: FactorType, + eigenvalues: Optional[Tensor], + ) -> Tensor: + """Multiply matrix with Kronecker factors for separated weight and bias. + + Args: + M_torch: Matrix for multiplication. + param_pos: Dictionary with positions of the weight and bias parameters. + aaT: Input covariance Kronecker factor or its eigenvectors. ``None`` for + biases. + ggT: Gradient covariance Kronecker factor or its eigenvectors. + eigenvalues: Eigenvalues of the (E)KFAC approximation when multiplying with + the eigendecomposition of the KFAC approximation. ``None`` for the + non-decomposed KFAC approximation. + + Returns: + Matrix-multiplication result ``KFAC @ M_torch``. + """ + for p_name, pos in param_pos.items(): + # for weights we need to multiply from the right with aaT + # for weights and biases we need to multiply from the left with ggT + if p_name == "weight": + M_w = rearrange(M_torch[pos], "m c_out ... -> m c_out (...)") + # If `eigenvalues` is not `None`, we transform to eigenbasis here + M_torch[pos] = einsum(M_w, aaT, "m i j, j k -> m i k") + + # If `eigenvalues` is not `None`, we convert to eigenbasis here + M_torch[pos] = einsum( + ggT.T if eigenvalues else ggT, M_torch[pos], "i j, m j ... -> m i ..." + ) + + if eigenvalues is not None: + # Multiply (broadcasted) by eigenvalues, convert back to original basis + M_torch[pos].mul_(eigenvalues[pos]) + if p_name == "weight": + M_torch[pos] = einsum(M_torch[pos], aaT, "m i j, k j -> m i k") + M_torch[pos] = einsum(ggT, M_torch[pos], "i j, m j ... -> m i ...") + + return M_torch + def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: """Apply KFAC to a matrix (multiple vectors) in PyTorch. @@ -429,7 +555,12 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: ``[D, K]`` with some ``K``. """ return_tensor, M_torch = self._check_input_type_and_preprocess(M_torch) - if not self._input_covariances and not self._gradient_covariances: + if ( + not self._input_covariances + and not self._gradient_covariances + and not self._input_covariances_eigenvectors + and not self._gradient_covariances_eigenvectors + ): self._compute_kfac() for mod_name, param_pos in self._mapping.items(): @@ -437,6 +568,16 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: if "weight" in param_pos: weight_shape = M_torch[param_pos["weight"]].shape + # get the Kronecker factors for the current module + if self._correct_eigenvalues: + aaT = self._input_covariances_eigenvectors.get(mod_name) + ggT = self._gradient_covariances_eigenvectors.get(mod_name) + eigenvalues = self._corrected_eigenvalues[mod_name] + else: + aaT = self._input_covariances.get(mod_name) + ggT = self._gradient_covariances.get(mod_name) + eigenvalues = None + # bias and weights are treated jointly if ( not self._separate_weight_and_bias @@ -444,33 +585,15 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: and "bias" in param_pos.keys() ): w_pos, b_pos = param_pos["weight"], param_pos["bias"] - # v denotes the free dimension for treating multiple vectors in parallel - M_w = rearrange(M_torch[w_pos], "v c_out ... 
-> v c_out (...)") - M_joint = cat([M_w, M_torch[b_pos].unsqueeze(-1)], dim=2) - aaT = self._input_covariances[mod_name] - ggT = self._gradient_covariances[mod_name] - M_joint = einsum(ggT, M_joint, aaT, "i j,v j k,k l -> v i l") - + M_w = rearrange(M_torch[w_pos], "m c_out ... -> m c_out (...)") + M_joint = cat([M_w, M_torch[b_pos].unsqueeze(2)], dim=2) + M_joint = self._left_and_right_multiply(M_joint, aaT, ggT, eigenvalues) w_cols = M_w.shape[2] M_torch[w_pos], M_torch[b_pos] = M_joint.split([w_cols, 1], dim=2) - - # for weights we need to multiply from the right with aaT - # for weights and biases we need to multiply from the left with ggT else: - for p_name, pos in param_pos.items(): - if p_name == "weight": - M_w = rearrange(M_torch[pos], "v c_out ... -> v c_out (...)") - M_torch[pos] = einsum( - M_w, - self._input_covariances[mod_name], - "v c_out j,j k -> v c_out k", - ) - - M_torch[pos] = einsum( - self._gradient_covariances[mod_name], - M_torch[pos], - "j k,v k ... -> v j ...", - ) + M_torch = self._separate_left_and_right_multiply( + M_torch, param_pos, aaT, ggT, eigenvalues + ) # restore original shapes if "weight" in param_pos: @@ -576,6 +699,24 @@ def _compute_kfac(self): output = self._model_func(X) self._compute_loss_and_backward(output, y) + if self._correct_eigenvalues: + # Compute the eigenvalue decomposition of the KFAC approximation + if not ( + self._input_covariances_eigenvalues + or self._gradient_covariances_eigenvalues + ): + self._compute_eigendecomposition() + + # Compute the corrected eigenvalues for the EKFAC approximation + self._compute_eigenvalue_correction_flag = True + for X, y in self._loop_over_data(desc="Eigenvalue correction"): + output = self._model_func(X) + self._compute_loss_and_backward(output, y) + self._compute_eigenvalue_correction_flag = False + + # Delete the cached activations + self._cached_activations.clear() + # clean up for handle in hook_handles: handle.remove() @@ -798,12 +939,96 @@ def _accumulate_gradient_covariance( / (self._N_data * self._mc_samples * self._num_per_example_loss_terms), }[self._loss_func.reduction] - covariance = einsum(g, g, "b i,b j->i j").mul_(correction) + if self._compute_eigenvalue_correction_flag: + # Compute the eigenvalue correction for the EKFAC approximation + self._compute_eigenvalue_correction(module_name, g, correction) + else: + # Compute and accumulate the gradient covariance + covariance = einsum(g, g, "b i, b j -> i j").mul_(correction) + self._gradient_covariances = self._set_or_add_( + self._gradient_covariances, module_name, covariance + ) - if module_name not in self._gradient_covariances: - self._gradient_covariances[module_name] = covariance + def _compute_eigenvalue_correction( + self, module_name: str, g: Tensor, correction: int + ): + r"""Compute the corrected eigenvalues for the EKFAC approximation. + + The corrected eigenvalues are computed as + :math:`\lambda_{\text{corrected}} = (Q_g^T G Q_a)^2`, where + :math:`Q_a` and :math:`Q_g` are the eigenvectors of the input and gradient + covariances, respectively, and ``G`` is the gradient matrix. The corrected + eigenvalues are used to correct the eigenvalues of the KFAC approximation + (EKFAC). + + Args: + module_name: Name of the module in the neural network. + g: The gradient w.r.t. the layer output. + correction: Correction factor for the eigenvalues. 
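# --- Illustrative sketch (not part of this patch): the eigenvalue correction
# described in the docstring above, written out for a single linear layer with
# per-example inputs ``a`` and per-example output gradients ``g``. In the KFE,
# the per-example weight gradient g_n a_n^T becomes outer(Q_g^T g_n, Q_a^T a_n),
# so the corrected eigenvalues are sums of elementwise squares. The 1/N factor
# stands in for the ``correction`` argument; the actual factor depends on the
# loss reduction. All names are local to this sketch.
import torch


def ekfac_corrected_eigenvalues(a: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """``a``: [N, d_in] layer inputs, ``g``: [N, d_out] output gradients."""
    N = a.shape[0]
    _, Q_a = torch.linalg.eigh(a.T @ a / N)  # eigenvectors of input covariance
    _, Q_g = torch.linalg.eigh(g.T @ g / N)  # eigenvectors of gradient covariance
    a_kfe, g_kfe = a @ Q_a, g @ Q_g  # rotate into the Kronecker-factored eigenbasis
    # lambda[i, j] = 1/N * sum_n (Q_g^T g_n a_n^T Q_a)[i, j]^2
    return torch.einsum("no,ni->oi", g_kfe**2, a_kfe**2) / N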
+ """ + param_pos = self._mapping[module_name] + aaT_eigenvectors = self._input_covariances_eigenvectors.get(module_name) + ggT_eigenvectors = self._gradient_covariances_eigenvectors.get(module_name) + + # Compute corrected eigenvalues for EKFAC. + if ( + not self._separate_weight_and_bias + and "weight" in param_pos.keys() + and "bias" in param_pos.keys() + ): + # Compute per-example gradient using the cached activations + per_example_gradient = einsum( + g, + self._cached_activations[module_name], + "shared d_out, shared d_in -> shared d_out d_in", + ) + # Transform the per-example gradient to the eigenbasis and square it + self._corrected_eigenvalues = self._set_or_add_( + self._corrected_eigenvalues, + module_name, + einsum( + ggT_eigenvectors, + per_example_gradient, + aaT_eigenvectors, + "d_out1 d_out2, ... d_out1 d_in1, d_in1 d_in2 -> ... d_out2 d_in2", + ) + .square_() + .sum(dim=0) + .mul_(correction), + ) else: - self._gradient_covariances[module_name].add_(covariance) + if module_name not in self._corrected_eigenvalues: + self._corrected_eigenvalues[module_name] = {} + for p_name, pos in param_pos.items(): + # Compute per-example gradient using the cached activations + per_example_gradient = ( + einsum( + g, + self._cached_activations[module_name], + "shared d_out, shared d_in -> shared d_out d_in", + ) + if p_name == "weight" + else g + ) + # Transform the per-example gradient to the eigenbasis and square it + if p_name == "weight": + per_example_gradient = einsum( + per_example_gradient, + aaT_eigenvectors, + "batch d_out d_in1, d_in1 d_in2 -> batch d_out d_in2", + ) + self._corrected_eigenvalues[module_name] = self._set_or_add_( + self._corrected_eigenvalues[module_name], + pos, + einsum( + ggT_eigenvectors, + per_example_gradient, + "d_out1 d_out2, batch d_out1 ... -> batch d_out2 ...", + ) + .square_() + .sum(dim=0) + .mul_(correction), + ) def _hook_accumulate_input_covariance( self, module: Module, inputs: Tuple[Tensor], module_name: str @@ -840,7 +1065,7 @@ def _hook_accumulate_input_covariance( if self._kfac_approx == KFACType.EXPAND: # KFAC-expand approximation - scale = x.shape[1:-1].numel() # sequence length + scale = x.shape[1:-1].numel() # weight sharing dimensions size x = rearrange(x, "batch ... d_in -> (batch ...) d_in") else: # KFAC-reduce approximation @@ -855,12 +1080,43 @@ def _hook_accumulate_input_covariance( ): x = cat([x, x.new_ones(x.shape[0], 1)], dim=1) - covariance = einsum(x, x, "b i,b j -> i j").div_(self._N_data * scale) + if self._compute_eigenvalue_correction_flag: + self._cached_activations[module_name] = x + else: + # Compute and accumulate the input covariance + covariance = einsum(x, x, "b i, b j -> i j").div_(self._N_data * scale) + self._input_covariances = self._set_or_add_( + self._input_covariances, module_name, covariance + ) + + @staticmethod + def _set_or_add_( + dictionary: Dict[str, Tensor], key: str, value: Tensor + ) -> Dict[str, Tensor]: + """Set or add a value to a dictionary entry. - if module_name not in self._input_covariances: - self._input_covariances[module_name] = covariance + Args: + dictionary: The dictionary to update. + key: The key to update. + value: The value to add. + + Returns: + The updated dictionary. + + Raises: + ValueError: If the types of the value and the dictionary entry are + incompatible. 
+ """ + if key not in dictionary: + dictionary[key] = value + elif isinstance(dictionary[key], Tensor) and isinstance(value, Tensor): + dictionary[key].add_(value) else: - self._input_covariances[module_name].add_(covariance) + raise ValueError( + "Incompatible types for addition: dictionary value of type " + f"{type(dictionary[key])} and value to be added of type {type(value)}." + ) + return dictionary @classmethod def compute_parameter_mapping( @@ -903,6 +1159,28 @@ def compute_parameter_mapping( return positions + def _compute_eigendecomposition(self) -> None: + """Compute the eigendecomposition of the KFAC approximation.""" + if not self._input_covariances and not self._gradient_covariances: + self._compute_kfac() + + for mod_name in self._mapping.keys(): + # Free up memory by deleting the Kronecker factors + aaT = self._input_covariances.pop(mod_name, None) + ggT = self._gradient_covariances.pop(mod_name, None) + + # Compute eigendecomposition of the Kronecker factors + if aaT is not None: + aaT_eigvals, aaT_eigvecs = eigh(aaT) + self._input_covariances_eigenvectors[mod_name] = aaT_eigvecs + self._input_covariances_eigenvalues[mod_name] = aaT_eigvals + del aaT + if ggT is not None: + ggT_eigvals, ggT_eigvecs = eigh(ggT) + self._gradient_covariances_eigenvectors[mod_name] = ggT_eigvecs + self._gradient_covariances_eigenvalues[mod_name] = ggT_eigvals + del ggT + @property def trace(self) -> Tensor: r"""Trace of the KFAC approximation. @@ -918,25 +1196,41 @@ def trace(self) -> Tensor: if self._trace is not None: return self._trace - if not self._input_covariances and not self._gradient_covariances: + if ( + not self._input_covariances + and not self._gradient_covariances + and not self._corrected_eigenvalues + ): self._compute_kfac() + # Initialize the trace self._trace = 0.0 - for mod_name, param_pos in self._mapping.items(): - tr_ggT = self._gradient_covariances[mod_name].trace() - if ( - not self._separate_weight_and_bias - and "weight" in param_pos.keys() - and "bias" in param_pos.keys() - ): - self._trace += self._input_covariances[mod_name].trace() * tr_ggT - else: - for p_name in param_pos.keys(): - self._trace += tr_ggT * ( - self._input_covariances[mod_name].trace() - if p_name == "weight" - else 1 - ) + + if self._correct_eigenvalues: + for corrected_eigenvalues in self._corrected_eigenvalues.values(): + if isinstance(corrected_eigenvalues, dict): + for val in corrected_eigenvalues.values(): + self._trace += val.sum() + else: + self._trace += corrected_eigenvalues.sum() + else: + # TODO: Also support the trace for eigendecomposition of KFAC + for mod_name, param_pos in self._mapping.items(): + tr_ggT = self._gradient_covariances[mod_name].trace() + if ( + not self._separate_weight_and_bias + and "weight" in param_pos.keys() + and "bias" in param_pos.keys() + ): + self._trace += self._input_covariances[mod_name].trace() * tr_ggT + else: + for p_name in param_pos.keys(): + self._trace += tr_ggT * ( + self._input_covariances[mod_name].trace() + if p_name == "weight" + else 1 + ) + return self._trace @property @@ -955,33 +1249,49 @@ def det(self) -> Tensor: if self._det is not None: return self._det - if not self._input_covariances and not self._gradient_covariances: + if ( + not self._input_covariances + and not self._gradient_covariances + and not self._corrected_eigenvalues + ): self._compute_kfac() + # Initialize the determinant self._det = 1.0 - for mod_name, param_pos in self._mapping.items(): - m = self._gradient_covariances[mod_name].shape[0] - det_ggT = 
self._gradient_covariances[mod_name].det() - if ( - not self._separate_weight_and_bias - and "weight" in param_pos.keys() - and "bias" in param_pos.keys() - ): - n = self._input_covariances[mod_name].shape[0] - det_aaT = self._input_covariances[mod_name].det() - self._det *= det_aaT.pow(m) * det_ggT.pow(n) - else: - for p_name in param_pos.keys(): - n = ( - self._input_covariances[mod_name].shape[0] - if p_name == "weight" - else 1 - ) - self._det *= det_ggT.pow(n) * ( - self._input_covariances[mod_name].det().pow(m) - if p_name == "weight" - else 1 - ) + + if self._correct_eigenvalues: + for corrected_eigenvalues in self._corrected_eigenvalues.values(): + if isinstance(corrected_eigenvalues, dict): + for val in corrected_eigenvalues.values(): + self._det *= val.prod() + else: + self._det *= corrected_eigenvalues.prod() + else: + # TODO: Also support the det for eigendecomposition of KFAC + for mod_name, param_pos in self._mapping.items(): + m = self._gradient_covariances[mod_name].shape[0] + det_ggT = self._gradient_covariances[mod_name].det() + if ( + not self._separate_weight_and_bias + and "weight" in param_pos.keys() + and "bias" in param_pos.keys() + ): + n = self._input_covariances[mod_name].shape[0] + det_aaT = self._input_covariances[mod_name].det() + self._det *= det_aaT.pow(m) * det_ggT.pow(n) + else: + for p_name in param_pos.keys(): + n = ( + self._input_covariances[mod_name].shape[0] + if p_name == "weight" + else 1 + ) + self._det *= det_ggT.pow(n) * ( + self._input_covariances[mod_name].det().pow(m) + if p_name == "weight" + else 1 + ) + return self._det @property @@ -1001,33 +1311,49 @@ def logdet(self) -> Tensor: if self._logdet is not None: return self._logdet - if not self._input_covariances and not self._gradient_covariances: + if ( + not self._input_covariances + and not self._gradient_covariances + and not self._corrected_eigenvalues + ): self._compute_kfac() + # Initialize the log determinant self._logdet = 0.0 - for mod_name, param_pos in self._mapping.items(): - m = self._gradient_covariances[mod_name].shape[0] - logdet_ggT = self._gradient_covariances[mod_name].logdet() - if ( - not self._separate_weight_and_bias - and "weight" in param_pos.keys() - and "bias" in param_pos.keys() - ): - n = self._input_covariances[mod_name].shape[0] - logdet_aaT = self._input_covariances[mod_name].logdet() - self._logdet += m * logdet_aaT + n * logdet_ggT - else: - for p_name in param_pos.keys(): - n = ( - self._input_covariances[mod_name].shape[0] - if p_name == "weight" - else 1 - ) - self._logdet += n * logdet_ggT + ( - m * self._input_covariances[mod_name].logdet() - if p_name == "weight" - else 0 - ) + + if self._correct_eigenvalues: + for corrected_eigenvalues in self._corrected_eigenvalues.values(): + if isinstance(corrected_eigenvalues, dict): + for val in corrected_eigenvalues.values(): + self._logdet += val.log().sum() + else: + self._logdet += corrected_eigenvalues.log().sum() + else: + # TODO: Also support the log det for eigendecomposition of KFAC + for mod_name, param_pos in self._mapping.items(): + m = self._gradient_covariances[mod_name].shape[0] + logdet_ggT = self._gradient_covariances[mod_name].logdet() + if ( + not self._separate_weight_and_bias + and "weight" in param_pos.keys() + and "bias" in param_pos.keys() + ): + n = self._input_covariances[mod_name].shape[0] + logdet_aaT = self._input_covariances[mod_name].logdet() + self._logdet += m * logdet_aaT + n * logdet_ggT + else: + for p_name in param_pos.keys(): + n = ( + 
self._input_covariances[mod_name].shape[0] + if p_name == "weight" + else 1 + ) + self._logdet += n * logdet_ggT + ( + m * self._input_covariances[mod_name].logdet() + if p_name == "weight" + else 0 + ) + return self._logdet @property @@ -1044,28 +1370,43 @@ def frobenius_norm(self) -> Tensor: if self._frobenius_norm is not None: return self._frobenius_norm - if not self._input_covariances and not self._gradient_covariances: + if ( + not self._input_covariances + and not self._gradient_covariances + and not self._corrected_eigenvalues + ): self._compute_kfac() + # Initialize the Frobenius norm self._frobenius_norm = 0.0 - for mod_name, param_pos in self._mapping.items(): - squared_frob_ggT = self._gradient_covariances[mod_name].square().sum() - if ( - not self._separate_weight_and_bias - and "weight" in param_pos.keys() - and "bias" in param_pos.keys() - ): - squared_frob_aaT = self._input_covariances[mod_name].square().sum() - self._frobenius_norm += squared_frob_aaT * squared_frob_ggT - else: - for p_name in param_pos.keys(): - self._frobenius_norm += squared_frob_ggT * ( - self._input_covariances[mod_name].square().sum() - if p_name == "weight" - else 1 - ) - self._frobenius_norm.sqrt_() - return self._frobenius_norm + + if self._correct_eigenvalues: + for corrected_eigenvalues in self._corrected_eigenvalues.values(): + if isinstance(corrected_eigenvalues, dict): + for val in corrected_eigenvalues.values(): + self._frobenius_norm += val.square().sum() + else: + self._frobenius_norm += corrected_eigenvalues.square().sum() + else: + # TODO: Also support the Frobenius norm for eigendecomposition of KFAC + for mod_name, param_pos in self._mapping.items(): + squared_frob_ggT = self._gradient_covariances[mod_name].square().sum() + if ( + not self._separate_weight_and_bias + and "weight" in param_pos.keys() + and "bias" in param_pos.keys() + ): + squared_frob_aaT = self._input_covariances[mod_name].square().sum() + self._frobenius_norm += squared_frob_aaT * squared_frob_ggT + else: + for p_name in param_pos.keys(): + self._frobenius_norm += squared_frob_ggT * ( + self._input_covariances[mod_name].square().sum() + if p_name == "weight" + else 1 + ) + + return self._frobenius_norm.sqrt_() def state_dict(self) -> Dict[str, Any]: """Return the state of the KFAC linear operator. 
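# --- Illustrative sketch (not part of this patch): the Kronecker-product
# identities behind the ``trace``, ``det``, ``logdet``, and ``frobenius_norm``
# properties above, checked numerically on random SPD factors. Names are local
# to this sketch; ``A`` plays the role of ``aaT`` (n x n) and ``G`` of ``ggT``
# (m x m).
import torch

n, m = 3, 2
A = torch.rand(n, n)
A = A @ A.T + torch.eye(n)
G = torch.rand(m, m)
G = G @ G.T + torch.eye(m)
K = torch.kron(G, A)  # per-layer KFAC block

assert torch.allclose(K.trace(), G.trace() * A.trace())
assert torch.allclose(torch.det(K), torch.det(G) ** n * torch.det(A) ** m)
assert torch.allclose(torch.logdet(K), n * torch.logdet(G) + m * torch.logdet(A))
# ||G ⊗ A||_F^2 = ||G||_F^2 * ||A||_F^2
assert torch.allclose(
    K.square().sum().sqrt(), (G.square().sum() * A.square().sum()).sqrt()
)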
@@ -1090,13 +1431,22 @@ def state_dict(self) -> Dict[str, Any]: "fisher_type": self._fisher_type, "mc_samples": self._mc_samples, "kfac_approx": self._kfac_approx, + "correct_eigenvalues": self._correct_eigenvalues, "num_per_example_loss_terms": self._num_per_example_loss_terms, "separate_weight_and_bias": self._separate_weight_and_bias, "num_data": self._N_data, # Kronecker factors (if computed) "input_covariances": self._input_covariances, "gradient_covariances": self._gradient_covariances, - # Properties (not necessarily computed) + # Kronecker factors eigendecomposition (if computed) + "input_covariances_eigenvectors": self._input_covariances_eigenvectors, + "input_covariances_eigenvalues": self._input_covariances_eigenvalues, + "gradient_covariances_eigenvectors": self._gradient_covariances_eigenvectors, + "gradient_covariances_eigenvalues": self._gradient_covariances_eigenvalues, + # Quantities for eigenvalue correction (if computed) + "cached_activations": self._cached_activations, + "corrected_eigenvalues": self._corrected_eigenvalues, + # Properties (if computed) "trace": self._trace, "det": self._det, "logdet": self._logdet, @@ -1137,6 +1487,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]): self._fisher_type = state_dict["fisher_type"] self._mc_samples = state_dict["mc_samples"] self._kfac_approx = state_dict["kfac_approx"] + self._correct_eigenvalues = state_dict["correct_eigenvalues"] self._num_per_example_loss_terms = state_dict["num_per_example_loss_terms"] self._separate_weight_and_bias = state_dict["separate_weight_and_bias"] self._N_data = state_dict["num_data"] @@ -1162,6 +1513,26 @@ def load_state_dict(self, state_dict: Dict[str, Any]): self._input_covariances = state_dict["input_covariances"] self._gradient_covariances = state_dict["gradient_covariances"] + # Set Kronecker factors eigendecomposition (if computed) + # TODO: should we check if the keys match the mapping keys? + self._input_covariances_eigenvectors = state_dict[ + "input_covariances_eigenvectors" + ] + self._input_covariances_eigenvalues = state_dict[ + "input_covariances_eigenvalues" + ] + self._gradient_covariances_eigenvectors = state_dict[ + "gradient_covariances_eigenvectors" + ] + self._gradient_covariances_eigenvalues = state_dict[ + "gradient_covariances_eigenvalues" + ] + + # Set quantities for eigenvalue correction (if computed) + # TODO: should we check if the keys match the mapping keys? 
+ self._cached_activations = state_dict["cached_activations"] + self._corrected_eigenvalues = state_dict["corrected_eigenvalues"] + # Set properties (not necessarily computed) self._trace = state_dict["trace"] self._det = state_dict["det"] @@ -1215,6 +1586,7 @@ def from_state_dict( fisher_type=state_dict["fisher_type"], mc_samples=state_dict["mc_samples"], kfac_approx=state_dict["kfac_approx"], + correct_eigenvalues=state_dict["correct_eigenvalues"], num_per_example_loss_terms=state_dict["num_per_example_loss_terms"], separate_weight_and_bias=state_dict["separate_weight_and_bias"], num_data=state_dict["num_data"], diff --git a/test/test_inverse.py b/test/test_inverse.py index f064ccc..150abc8 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -440,6 +440,9 @@ def test_KFAC_inverse_heuristically_damped_matmat( # noqa: C901 @mark.parametrize( "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"] ) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_KFAC_inverse_exactly_damped_matmat( case: Tuple[ Module, @@ -450,6 +453,7 @@ def test_KFAC_inverse_exactly_damped_matmat( cache: bool, exclude: str, separate_weight_and_bias: bool, + correct_eigenvalues: bool, delta: float = 1e-2, ): """Test matrix-matrix multiplication by an inverse (exactly) damped KFAC approximation.""" @@ -479,6 +483,7 @@ def test_KFAC_inverse_exactly_damped_matmat( batch_size_fn=batch_size_fn, separate_weight_and_bias=separate_weight_and_bias, check_deterministic=False, + correct_eigenvalues=correct_eigenvalues, ) KFAC.dtype = float64 @@ -511,7 +516,7 @@ def test_KFAC_inverse_exactly_damped_matmat( report_nonclose(inv_KFAC @ X, inv_KFAC_naive @ X) assert inv_KFAC._cache == cache - if cache: + if cache and not correct_eigenvalues: # test that the cache is not empty assert len(inv_KFAC._inverse_input_covariances) > 0 assert len(inv_KFAC._inverse_gradient_covariances) > 0 @@ -521,6 +526,9 @@ def test_KFAC_inverse_exactly_damped_matmat( assert len(inv_KFAC._inverse_gradient_covariances) == 0 +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_KFAC_inverse_damped_torch_matmat( case: Tuple[ Module, @@ -528,6 +536,7 @@ def test_KFAC_inverse_damped_torch_matmat( List[Parameter], Iterable[Tuple[torch.Tensor, torch.Tensor]], ], + correct_eigenvalues: bool, delta: float = 1e-2, ): """Test torch matrix-matrix multiplication by an inverse damped KFAC approximation.""" @@ -552,9 +561,14 @@ def test_KFAC_inverse_damped_torch_matmat( data, batch_size_fn=batch_size_fn, check_deterministic=False, + correct_eigenvalues=correct_eigenvalues, ) KFAC.dtype = float64 - inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(delta, delta)) + inv_KFAC = KFACInverseLinearOperator( + KFAC, + damping=delta if correct_eigenvalues else (delta, delta), + use_exact_damping=True if correct_eigenvalues else False, + ) device = KFAC._device num_vectors = 2 @@ -584,6 +598,9 @@ def test_KFAC_inverse_damped_torch_matmat( report_nonclose(inv_KFAC_X, kfac_x_numpy) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_KFAC_inverse_damped_torch_matvec( case: Tuple[ Module, @@ -591,6 +608,7 @@ def test_KFAC_inverse_damped_torch_matvec( List[Parameter], Iterable[Tuple[torch.Tensor, torch.Tensor]], ], + correct_eigenvalues: bool, delta: float = 1e-2, ): """Test torch matrix-vector multiplication by an inverse damped KFAC approximation.""" @@ -615,9 +633,14 @@ def 
test_KFAC_inverse_damped_torch_matvec( data, batch_size_fn=batch_size_fn, check_deterministic=False, + correct_eigenvalues=correct_eigenvalues, ) KFAC.dtype = float64 - inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(delta, delta)) + inv_KFAC = KFACInverseLinearOperator( + KFAC, + damping=delta if correct_eigenvalues else (delta, delta), + use_exact_damping=True if correct_eigenvalues else False, + ) device = KFAC._device x = torch.rand(KFAC.shape[1], dtype=dtype, device=device) @@ -647,7 +670,10 @@ def test_KFAC_inverse_damped_torch_matvec( report_nonclose(inv_KFAC @ x.cpu().numpy(), inv_KFAC_x.cpu().numpy()) -def test_KFAC_inverse_save_and_load_state_dict(): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_KFAC_inverse_save_and_load_state_dict(correct_eigenvalues): """Test that KFACInverseLinearOperator can be saved and loaded from state dict.""" torch.manual_seed(0) batch_size, D_in, D_out = 4, 3, 2 @@ -662,11 +688,16 @@ def test_KFAC_inverse_save_and_load_state_dict(): MSELoss(reduction="sum"), params, [(X, y)], + correct_eigenvalues=correct_eigenvalues, ) # create inverse KFAC inv_kfac = KFACInverseLinearOperator( - kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False + kfac, + damping=1e-2, + use_exact_damping=True if correct_eigenvalues else False, + use_heuristic_damping=False if correct_eigenvalues else True, + retry_double_precision=False, ) _ = inv_kfac @ eye(kfac.shape[1]) # to trigger inverse computation @@ -681,7 +712,9 @@ def test_KFAC_inverse_save_and_load_state_dict(): inv_kfac_wrong.load_state_dict(torch.load("inv_kfac_state_dict.pt")) # create new inverse KFAC and load state dict - inv_kfac_new = KFACInverseLinearOperator(kfac) + inv_kfac_new = KFACInverseLinearOperator( + kfac, use_exact_damping=True if correct_eigenvalues else False + ) inv_kfac_new.load_state_dict(torch.load("inv_kfac_state_dict.pt")) # clean up os.remove("inv_kfac_state_dict.pt") @@ -692,7 +725,10 @@ def test_KFAC_inverse_save_and_load_state_dict(): report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec) -def test_KFAC_inverse_from_state_dict(): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_KFAC_inverse_from_state_dict(correct_eigenvalues): """Test that KFACInverseLinearOperator can be created from state dict.""" torch.manual_seed(0) batch_size, D_in, D_out = 4, 3, 2 @@ -707,11 +743,16 @@ def test_KFAC_inverse_from_state_dict(): MSELoss(reduction="sum"), params, [(X, y)], + correct_eigenvalues=correct_eigenvalues, ) # create inverse KFAC and save state dict inv_kfac = KFACInverseLinearOperator( - kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False + kfac, + damping=1e-2, + use_exact_damping=True if correct_eigenvalues else False, + use_heuristic_damping=False if correct_eigenvalues else True, + retry_double_precision=False, ) state_dict = inv_kfac.state_dict() @@ -724,7 +765,10 @@ def test_KFAC_inverse_from_state_dict(): report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec) -def test_torch_matvec_list_output_shapes(cnn_case): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_torch_matvec_list_output_shapes(cnn_case, correct_eigenvalues): """Test output shapes with list input format (issue #124).""" model, loss_func, params, data, batch_size_fn = cnn_case kfac = KFACLinearOperator( @@ -733,8 +777,11 @@ def test_torch_matvec_list_output_shapes(cnn_case): params, 
data, batch_size_fn=batch_size_fn, + correct_eigenvalues=correct_eigenvalues, + ) + inv_kfac = KFACInverseLinearOperator( + kfac, damping=1e-2, use_exact_damping=True if correct_eigenvalues else False ) - inv_kfac = KFACInverseLinearOperator(kfac, damping=1e-2) vec = [torch.rand_like(p) for p in kfac._params] out_list = inv_kfac.torch_matvec(vec) assert len(out_list) == len(kfac._params) diff --git a/test/test_kfac.py b/test/test_kfac.py index e7f8d61..99f8b88 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -1,15 +1,16 @@ """Contains tests for ``curvlinops.kfac``.""" import os +from contextlib import nullcontext from test.cases import DEVICES, DEVICES_IDS from test.utils import ( Conv2dModel, UnetModel, WeightShareModel, binary_classification_targets, + block_diagonal, classification_targets, compare_state_dicts, - ggn_block_diagonal, regression_targets, ) from typing import Dict, Iterable, List, Tuple, Union @@ -19,7 +20,6 @@ from numpy import eye from numpy.linalg import det, norm, slogdet from pytest import mark, raises, skip -from scipy.linalg import block_diag from torch import Tensor, allclose, cat, cuda, device from torch import eye as torch_eye from torch import isinf, isnan, load, manual_seed, rand, rand_like, randperm, save @@ -35,8 +35,8 @@ Sequential, ) +from curvlinops import EFLinearOperator, GGNLinearOperator from curvlinops.examples.utils import report_nonclose -from curvlinops.gradient_moments import EFLinearOperator from curvlinops.kfac import FisherType, KFACLinearOperator, KFACType @@ -47,6 +47,9 @@ "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_type2( kfac_exact_case: Tuple[ Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] @@ -54,6 +57,7 @@ def test_kfac_type2( shuffle: bool, exclude: str, separate_weight_and_bias: bool, + correct_eigenvalues: bool, ): """Test the KFAC implementation against the exact GGN. @@ -65,6 +69,7 @@ def test_kfac_type2( or ``None``. separate_weight_and_bias: Whether to treat weight and bias as separate blocks in the KFAC matrix. + correct_eigenvalues: Whether EKFAC should be used. 
""" assert exclude in [None, "weight", "bias"] model, loss_func, params, data, batch_size_fn = kfac_exact_case @@ -77,7 +82,8 @@ def test_kfac_type2( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn = ggn_block_diagonal( + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, @@ -93,10 +99,11 @@ def test_kfac_type2( batch_size_fn=batch_size_fn, fisher_type=FisherType.TYPE2, separate_weight_and_bias=separate_weight_and_bias, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) - report_nonclose(ggn, kfac_mat) + report_nonclose(ggn, kfac_mat, atol=1e-6) # Check that input covariances were not computed if exclude == "weight": @@ -111,6 +118,9 @@ def test_kfac_type2( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_type2_weight_sharing( kfac_weight_sharing_exact_case: Tuple[ Union[WeightShareModel, Conv2dModel], @@ -122,6 +132,7 @@ def test_kfac_type2_weight_sharing( shuffle: bool, exclude: str, separate_weight_and_bias: bool, + correct_eigenvalues: bool, ): """Test KFAC for linear weight-sharing layers against the exact GGN. @@ -135,6 +146,7 @@ def test_kfac_type2_weight_sharing( or ``None``. separate_weight_and_bias: Whether to treat weight and bias as separate blocks in the KFAC matrix. + correct_eigenvalues: Whether EKFAC should be used. """ assert exclude in [None, "weight", "bias"] model, loss_func, params, data, batch_size_fn = kfac_weight_sharing_exact_case @@ -152,7 +164,8 @@ def test_kfac_type2_weight_sharing( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn = ggn_block_diagonal( + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, @@ -169,6 +182,7 @@ def test_kfac_type2_weight_sharing( fisher_type=FisherType.TYPE2, kfac_approx=setting, # choose KFAC approximation consistent with setting separate_weight_and_bias=separate_weight_and_bias, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -180,11 +194,15 @@ def test_kfac_type2_weight_sharing( @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_mc( kfac_exact_case: Tuple[ Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] ], shuffle: bool, + correct_eigenvalues: bool, ): """Test the KFAC implementation using MC samples against the exact GGN. @@ -192,6 +210,7 @@ def test_kfac_mc( kfac_exact_case: A fixture that returns a model, loss function, list of parameters, and data. shuffle: Whether to shuffle the parameters before computing the KFAC matrix. + correct_eigenvalues: Whether EKFAC should be used. 
""" model, loss_func, params, data, batch_size_fn = kfac_exact_case @@ -199,8 +218,8 @@ def test_kfac_mc( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn = ggn_block_diagonal( - model, loss_func, params, data, batch_size_fn=batch_size_fn + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, data, batch_size_fn=batch_size_fn ) kfac = KFACLinearOperator( model, @@ -208,7 +227,9 @@ def test_kfac_mc( params, data, batch_size_fn=batch_size_fn, - mc_samples=2_000, + fisher_type=FisherType.MC, + mc_samples=3_000, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -220,6 +241,9 @@ def test_kfac_mc( @mark.parametrize("setting", [KFACType.EXPAND, KFACType.REDUCE]) @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_mc_weight_sharing( kfac_weight_sharing_exact_case: Tuple[ Union[WeightShareModel, Conv2dModel], @@ -229,6 +253,7 @@ def test_kfac_mc_weight_sharing( ], setting: str, shuffle: bool, + correct_eigenvalues: bool, ): """Test KFAC-MC for linear layers with weight sharing against the exact GGN. @@ -238,6 +263,7 @@ def test_kfac_mc_weight_sharing( setting: The weight-sharing setting to use. Can be ``KFACType.EXPAND`` or ``KFACType.REDUCE``. shuffle: Whether to shuffle the parameters before computing the KFAC matrix. + correct_eigenvalues: Whether EKFAC should be used. """ model, loss_func, params, data, batch_size_fn = kfac_weight_sharing_exact_case model.setting = setting @@ -250,8 +276,8 @@ def test_kfac_mc_weight_sharing( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn = ggn_block_diagonal( - model, loss_func, params, data, batch_size_fn=batch_size_fn + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, data, batch_size_fn=batch_size_fn ) kfac = KFACLinearOperator( model, @@ -260,8 +286,9 @@ def test_kfac_mc_weight_sharing( data, batch_size_fn=batch_size_fn, fisher_type=FisherType.MC, - mc_samples=2_000, + mc_samples=4_000, kfac_approx=setting, # choose KFAC approximation consistent with setting + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -271,6 +298,9 @@ def test_kfac_mc_weight_sharing( report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_one_datum( kfac_exact_one_datum_case: Tuple[ Module, @@ -278,11 +308,12 @@ def test_kfac_one_datum( List[Parameter], Iterable[Tuple[Tensor, Tensor]], ], + correct_eigenvalues: bool, ): model, loss_func, params, data, batch_size_fn = kfac_exact_one_datum_case - ggn = ggn_block_diagonal( - model, loss_func, params, data, batch_size_fn=batch_size_fn + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, data, batch_size_fn=batch_size_fn ) kfac = KFACLinearOperator( model, @@ -291,12 +322,16 @@ def test_kfac_one_datum( data, batch_size_fn=batch_size_fn, fisher_type=FisherType.TYPE2, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) report_nonclose(ggn, kfac_mat) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_mc_one_datum( kfac_exact_one_datum_case: Tuple[ Module, @@ -304,11 +339,12 @@ def test_kfac_mc_one_datum( List[Parameter], Iterable[Tuple[Tensor, Tensor]], ], + correct_eigenvalues: bool, ): model, loss_func, params, data, batch_size_fn = 
kfac_exact_one_datum_case - ggn = ggn_block_diagonal( - model, loss_func, params, data, batch_size_fn=batch_size_fn + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, data, batch_size_fn=batch_size_fn ) kfac = KFACLinearOperator( model, @@ -316,7 +352,9 @@ def test_kfac_mc_one_datum( params, data, batch_size_fn=batch_size_fn, + fisher_type=FisherType.MC, mc_samples=11_000, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -326,6 +364,15 @@ def test_kfac_mc_one_datum( report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol) +@mark.parametrize( + "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"] +) +@mark.parametrize( + "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] +) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_ef_one_datum( kfac_exact_one_datum_case: Tuple[ Module, @@ -333,16 +380,25 @@ def test_kfac_ef_one_datum( List[Parameter], Iterable[Tuple[Tensor, Tensor]], ], + separate_weight_and_bias: bool, + exclude: str, + correct_eigenvalues: bool, ): model, loss_func, params, data, batch_size_fn = kfac_exact_one_datum_case - ef_blocks = [] # list of per-parameter EFs - for param in params: - ef = EFLinearOperator( - model, loss_func, [param], data, batch_size_fn=batch_size_fn - ) - ef_blocks.append(ef @ eye(ef.shape[1])) - ef = block_diag(*ef_blocks) + if exclude is not None: + names = {p.data_ptr(): name for name, p in model.named_parameters()} + params = [p for p in params if exclude not in names[p.data_ptr()]] + + ef = block_diagonal( + EFLinearOperator, + model, + loss_func, + params, + data, + batch_size_fn=batch_size_fn, + separate_weight_and_bias=separate_weight_and_bias, + ) kfac = KFACLinearOperator( model, @@ -350,11 +406,13 @@ def test_kfac_ef_one_datum( params, data, batch_size_fn=batch_size_fn, + separate_weight_and_bias=separate_weight_and_bias, fisher_type=FisherType.EMPIRICAL, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) - report_nonclose(ef, kfac_mat) + report_nonclose(ef, kfac_mat, atol=1e-7) @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS) @@ -375,7 +433,7 @@ def test_kfac_inplace_activations(dev: device): params = list(model.parameters()) # 1) compare KFAC and GGN - ggn = ggn_block_diagonal(model, loss_func, params, data) + ggn = block_diagonal(GGNLinearOperator, model, loss_func, params, data) kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -389,7 +447,7 @@ def test_kfac_inplace_activations(dev: device): for mod in model.modules(): if hasattr(mod, "inplace"): mod.inplace = False - ggn_no_inplace = ggn_block_diagonal(model, loss_func, params, data) + ggn_no_inplace = block_diagonal(GGNLinearOperator, model, loss_func, params, data) report_nonclose(ggn, ggn_no_inplace) @@ -400,11 +458,15 @@ def test_kfac_inplace_activations(dev: device): ) @mark.parametrize("reduction", ["mean", "sum"]) @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_multi_dim_output( fisher_type: str, loss: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], reduction: str, dev: device, + correct_eigenvalues: bool, ): """Test the KFAC implementation for >2d outputs (using a 3d and 4d output). @@ -413,6 +475,7 @@ def test_multi_dim_output( loss: The loss function to use. reduction: The reduction to use for the loss function. 
         dev: The device to run the test on.
+        correct_eigenvalues: Whether EKFAC should be used.
     """
     manual_seed(0)
     # set up loss function, data, and model
@@ -448,13 +511,22 @@ def test_multi_dim_output(

     # KFAC for deep linear network with 4d input and output
     params = list(model.parameters())
-    kfac = KFACLinearOperator(
-        model,
-        loss_func,
-        params,
-        data,
-        fisher_type=fisher_type,
-    )
+    context = (
+        raises(NotImplementedError, match="eigenvalues")
+        if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY
+        else nullcontext()
+    )  # EKFAC for FOOF is currently not supported
+    with context:
+        kfac = KFACLinearOperator(
+            model,
+            loss_func,
+            params,
+            data,
+            fisher_type=fisher_type,
+            correct_eigenvalues=correct_eigenvalues,
+        )
+    if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY:
+        return
     kfac_mat = kfac @ eye(kfac.shape[1])

     # KFAC for deep linear network with 4d input and equivalent 2d output
@@ -479,6 +551,7 @@ def test_multi_dim_output(
         params_flat,
         data_flat,
         fisher_type=fisher_type,
+        correct_eigenvalues=correct_eigenvalues,
     )
     kfac_flat_mat = kfac_flat @ eye(kfac_flat.shape[1])

@@ -489,11 +562,15 @@
 @mark.parametrize(
     "loss", [MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], ids=["mse", "ce", "bce"]
 )
+@mark.parametrize(
+    "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"]
+)
 @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
-def test_expand_setting_scaling(
+def test_expand_setting_scaling(  # noqa: C901
     fisher_type: str,
     loss: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
     dev: device,
+    correct_eigenvalues: bool,
 ):
     """Test KFAC for correct scaling for expand setting with mean reduction loss.

@@ -503,6 +580,7 @@ def test_expand_setting_scaling(
         fisher_type: The type of Fisher matrix to use.
         loss: The loss function to use.
         dev: The device to run the test on.
+        correct_eigenvalues: Whether EKFAC should be used.
""" manual_seed(0) @@ -528,14 +606,26 @@ def test_expand_setting_scaling( params = list(model.parameters()) # KFAC with sum reduction + params = list(model.parameters()) loss_func = loss(reduction="sum").to(dev) - kfac_sum = KFACLinearOperator( - model, - loss_func, - params, - data, - fisher_type=fisher_type, - ) + + context = ( + raises(NotImplementedError, match="eigenvalues") + if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY + else nullcontext() + ) # EKFAC for FOOF is currently not supported + with context: + kfac_sum = KFACLinearOperator( + model, + loss_func, + params, + data, + fisher_type=fisher_type, + correct_eigenvalues=correct_eigenvalues, + ) + if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY: + return + # FOOF does not scale the gradient covariances, even when using a mean reduction if fisher_type != FisherType.FORWARD_ONLY: # Simulate a mean reduction by manually scaling the gradient covariances @@ -544,8 +634,17 @@ def test_expand_setting_scaling( output_random_variable_size = 3 # MSE loss averages over number of output channels loss_term_factor *= output_random_variable_size - for ggT in kfac_sum._gradient_covariances.values(): - ggT /= kfac_sum._N_data * loss_term_factor + correction = kfac_sum._N_data * loss_term_factor + if correct_eigenvalues: + for eigenvalues in kfac_sum._corrected_eigenvalues.values(): + if isinstance(eigenvalues, dict): + for eigenvals in eigenvalues.values(): + eigenvals /= correction + else: + eigenvalues /= correction + else: + for ggT in kfac_sum._gradient_covariances.values(): + ggT /= correction kfac_simulated_mean_mat = kfac_sum @ eye(kfac_sum.shape[1]) # KFAC with mean reduction @@ -556,10 +655,11 @@ def test_expand_setting_scaling( params, data, fisher_type=fisher_type, + correct_eigenvalues=correct_eigenvalues, ) kfac_mean_mat = kfac_mean @ eye(kfac_mean.shape[1]) - report_nonclose(kfac_simulated_mean_mat, kfac_mean_mat) + report_nonclose(kfac_simulated_mean_mat, kfac_mean_mat, atol=1e-7) def test_bug_device_change_invalidates_parameter_mapping(): @@ -595,16 +695,31 @@ def test_bug_device_change_invalidates_parameter_mapping(): report_nonclose(kfac_x_gpu, kfac_x_cpu) -def test_torch_matmat(case): +@mark.parametrize( + "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"] +) +@mark.parametrize( + "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] +) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_torch_matmat(case, separate_weight_and_bias, exclude, correct_eigenvalues): """Test that the torch_matmat method of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case + if exclude is not None: + names = {p.data_ptr(): name for name, p in model.named_parameters()} + params = [p for p in params if exclude not in names[p.data_ptr()]] + kfac = KFACLinearOperator( model, loss_func, params, data, batch_size_fn=batch_size_fn, + separate_weight_and_bias=separate_weight_and_bias, + correct_eigenvalues=correct_eigenvalues, ) device = kfac._device # KFAC.dtype is a numpy data type @@ -635,16 +750,31 @@ def test_torch_matmat(case): report_nonclose(kfac_x, kfac_x_numpy, rtol=1e-4) -def test_torch_matvec(case): +@mark.parametrize( + "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"] +) +@mark.parametrize( + "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] +) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", 
"eigenvalue_corrected"] +) +def test_torch_matvec(case, separate_weight_and_bias, exclude, correct_eigenvalues): """Test that the torch_matvec method of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case + if exclude is not None: + names = {p.data_ptr(): name for name, p in model.named_parameters()} + params = [p for p in params if exclude not in names[p.data_ptr()]] + kfac = KFACLinearOperator( model, loss_func, params, data, batch_size_fn=batch_size_fn, + separate_weight_and_bias=separate_weight_and_bias, + correct_eigenvalues=correct_eigenvalues, ) device = kfac._device # KFAC.dtype is a numpy data type @@ -683,7 +813,10 @@ def test_torch_matvec(case): report_nonclose(kfac_x, kfac_x_numpy) -def test_torch_matvec_list_output_shapes(cnn_case): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_torch_matvec_list_output_shapes(cnn_case, correct_eigenvalues): """Test output shapes with list input format (issue #124).""" model, loss_func, params, data, batch_size_fn = cnn_case kfac = KFACLinearOperator( @@ -692,6 +825,7 @@ def test_torch_matvec_list_output_shapes(cnn_case): params, data, batch_size_fn=batch_size_fn, + correct_eigenvalues=correct_eigenvalues, ) vec = [rand_like(p) for p in kfac._params] out_list = kfac.torch_matvec(vec) @@ -711,7 +845,12 @@ def test_torch_matvec_list_output_shapes(cnn_case): @mark.parametrize( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) -def test_trace(case, exclude, separate_weight_and_bias, check_deterministic): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_trace( + case, exclude, separate_weight_and_bias, check_deterministic, correct_eigenvalues +): """Test that the trace property of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case @@ -727,6 +866,7 @@ def test_trace(case, exclude, separate_weight_and_bias, check_deterministic): batch_size_fn=batch_size_fn, separate_weight_and_bias=separate_weight_and_bias, check_deterministic=check_deterministic, + correct_eigenvalues=correct_eigenvalues, ) # Check for equivalence of trace property and naive trace computation @@ -751,7 +891,12 @@ def test_trace(case, exclude, separate_weight_and_bias, check_deterministic): @mark.parametrize( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) -def test_frobenius_norm(case, exclude, separate_weight_and_bias, check_deterministic): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_frobenius_norm( + case, exclude, separate_weight_and_bias, check_deterministic, correct_eigenvalues +): """Test that the Frobenius norm property of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case @@ -767,6 +912,7 @@ def test_frobenius_norm(case, exclude, separate_weight_and_bias, check_determini batch_size_fn=batch_size_fn, separate_weight_and_bias=separate_weight_and_bias, check_deterministic=check_deterministic, + correct_eigenvalues=correct_eigenvalues, ) # Check for equivalence of frobenius_norm property and the naive computation @@ -791,7 +937,12 @@ def test_frobenius_norm(case, exclude, separate_weight_and_bias, check_determini @mark.parametrize( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) -def test_det(case, exclude, separate_weight_and_bias, check_deterministic): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", 
"eigenvalue_corrected"] +) +def test_det( + case, exclude, separate_weight_and_bias, check_deterministic, correct_eigenvalues +): """Test that the determinant property of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case @@ -807,21 +958,32 @@ def test_det(case, exclude, separate_weight_and_bias, check_deterministic): batch_size_fn=batch_size_fn, separate_weight_and_bias=separate_weight_and_bias, check_deterministic=check_deterministic, + correct_eigenvalues=correct_eigenvalues, ) # add damping manually to avoid singular matrices if not check_deterministic: kfac._compute_kfac() - assert kfac._input_covariances or kfac._gradient_covariances + delta = 1.0 # requires much larger damping value compared to ``logdet`` - for aaT in kfac._input_covariances.values(): - aaT.add_( - torch_eye(aaT.shape[0], dtype=aaT.dtype, device=aaT.device), alpha=delta - ) - for ggT in kfac._gradient_covariances.values(): - ggT.add_( - torch_eye(ggT.shape[0], dtype=ggT.dtype, device=ggT.device), alpha=delta - ) + if correct_eigenvalues: + assert kfac._corrected_eigenvalues + for eigenvalues in kfac._corrected_eigenvalues.values(): + if isinstance(eigenvalues, dict): + for eigenvals in eigenvalues.values(): + eigenvals.add_(delta) + else: + eigenvalues.add_(delta) + else: + assert kfac._input_covariances or kfac._gradient_covariances + for aaT in kfac._input_covariances.values(): + aaT.add_( + torch_eye(aaT.shape[0], dtype=aaT.dtype, device=aaT.device), alpha=delta + ) + for ggT in kfac._gradient_covariances.values(): + ggT.add_( + torch_eye(ggT.shape[0], dtype=ggT.dtype, device=ggT.device), alpha=delta + ) # Check for equivalence of the det property and naive determinant computation determinant = kfac.det @@ -847,7 +1009,12 @@ def test_det(case, exclude, separate_weight_and_bias, check_deterministic): @mark.parametrize( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) -def test_logdet(case, exclude, separate_weight_and_bias, check_deterministic): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_logdet( + case, exclude, separate_weight_and_bias, check_deterministic, correct_eigenvalues +): """Test that the log determinant property of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case @@ -863,21 +1030,32 @@ def test_logdet(case, exclude, separate_weight_and_bias, check_deterministic): batch_size_fn=batch_size_fn, separate_weight_and_bias=separate_weight_and_bias, check_deterministic=check_deterministic, + correct_eigenvalues=correct_eigenvalues, ) # add damping manually to avoid singular matrices if not check_deterministic: kfac._compute_kfac() - assert kfac._input_covariances or kfac._gradient_covariances + delta = 1e-3 # only requires much smaller damping value compared to ``det`` - for aaT in kfac._input_covariances.values(): - aaT.add_( - torch_eye(aaT.shape[0], dtype=aaT.dtype, device=aaT.device), alpha=delta - ) - for ggT in kfac._gradient_covariances.values(): - ggT.add_( - torch_eye(ggT.shape[0], dtype=ggT.dtype, device=ggT.device), alpha=delta - ) + if correct_eigenvalues: + assert kfac._corrected_eigenvalues + for eigenvalues in kfac._corrected_eigenvalues.values(): + if isinstance(eigenvalues, dict): + for eigenvals in eigenvalues.values(): + eigenvals.add_(delta) + else: + eigenvalues.add_(delta) + else: + assert kfac._input_covariances or kfac._gradient_covariances + for aaT in kfac._input_covariances.values(): + aaT.add_( + torch_eye(aaT.shape[0], 
dtype=aaT.dtype, device=aaT.device), alpha=delta + ) + for ggT in kfac._gradient_covariances.values(): + ggT.add_( + torch_eye(ggT.shape[0], dtype=ggT.dtype, device=ggT.device), alpha=delta + ) # Check for equivalence of the logdet property and naive log determinant computation log_det = kfac.logdet @@ -900,11 +1078,15 @@ def test_logdet(case, exclude, separate_weight_and_bias, check_deterministic): "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_forward_only_fisher_type( case: Tuple[Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]], shuffle: bool, exclude: str, separate_weight_and_bias: bool, + correct_eigenvalues: bool, ): """Test the KFAC with forward-only Fisher (used for FOOF) implementation. @@ -916,6 +1098,7 @@ def test_forward_only_fisher_type( or ``None``. separate_weight_and_bias: Whether to treat weight and bias as separate blocks in the KFAC matrix. + correct_eigenvalues: Whether EKFAC should be used. """ assert exclude in [None, "weight", "bias"] model, loss_func, params, data, batch_size_fn = case @@ -946,16 +1129,25 @@ def test_forward_only_fisher_type( ) simulated_foof_mat = foof_simulated @ eye(foof_simulated.shape[1]) - # Compute KFAC with `fisher_type=FisherType.FORWARD_ONLY` - foof = KFACLinearOperator( - model, - loss_func, - params, - data, - batch_size_fn=batch_size_fn, - separate_weight_and_bias=separate_weight_and_bias, - fisher_type=FisherType.FORWARD_ONLY, - ) + # Compute KFAC with `fisher_type=FisherType.FORWARD_ONLY + context = ( + raises(NotImplementedError, match="eigenvalues") + if correct_eigenvalues + else nullcontext() + ) # EKFAC for FOOF is currently not supported + with context: + foof = KFACLinearOperator( + model, + loss_func, + params, + data, + batch_size_fn=batch_size_fn, + separate_weight_and_bias=separate_weight_and_bias, + fisher_type=FisherType.FORWARD_ONLY, + correct_eigenvalues=correct_eigenvalues, + ) + if correct_eigenvalues: + return foof_mat = foof @ eye(foof.shape[1]) # Check for equivalence @@ -1014,7 +1206,8 @@ def test_forward_only_fisher_type_exact_case( params = [params[i] for i in permutation] # Compute exact block-diagonal GGN - ggn = ggn_block_diagonal( + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, @@ -1118,7 +1311,8 @@ def test_forward_only_fisher_type_exact_weight_sharing_case( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn = ggn_block_diagonal( + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, @@ -1189,7 +1383,10 @@ def test_kfac_does_affect_grad(): assert allclose(grad_before, p.grad) -def test_save_and_load_state_dict(): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_save_and_load_state_dict(correct_eigenvalues): """Test that KFACLinearOperator can be saved and loaded from state dict.""" manual_seed(0) batch_size, D_in, D_out = 4, 3, 2 @@ -1204,6 +1401,7 @@ def test_save_and_load_state_dict(): MSELoss(reduction="sum"), params, [(X, y)], + correct_eigenvalues=correct_eigenvalues, ) # save state dict @@ -1260,7 +1458,10 @@ def test_save_and_load_state_dict(): report_nonclose(kfac @ test_vec, kfac_new @ test_vec) -def test_from_state_dict(): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def 
test_from_state_dict(correct_eigenvalues): """Test that KFACLinearOperator can be created from state dict.""" manual_seed(0) batch_size, D_in, D_out = 4, 3, 2 @@ -1275,6 +1476,7 @@ def test_from_state_dict(): MSELoss(reduction="sum"), params, [(X, y)], + correct_eigenvalues=correct_eigenvalues, ) # save state dict diff --git a/test/utils.py b/test/utils.py index dcc656b..0975345 100644 --- a/test/utils.py +++ b/test/utils.py @@ -33,7 +33,7 @@ Upsample, ) -from curvlinops import GGNLinearOperator +from curvlinops._base import _LinearOperator def get_available_devices() -> List[device]: @@ -87,7 +87,8 @@ def regression_targets(size: Tuple[int]) -> Tensor: return rand(*size) -def ggn_block_diagonal( +def block_diagonal( + linear_operator: _LinearOperator, model: Module, loss_func: Module, params: List[Parameter], @@ -95,26 +96,29 @@ def ggn_block_diagonal( batch_size_fn: Optional[Callable[[MutableMapping], int]] = None, separate_weight_and_bias: bool = True, ) -> ndarray: - """Compute the block-diagonal GGN. + """Compute the block-diagonal of the matrix induced by a linear operator. Args: + linear_operator: The linear operator. model: The neural network. loss_func: The loss function. - params: The parameters w.r.t. which the GGN block-diagonals will be computed. + params: The parameters w.r.t. which the block-diagonal will be computed for. data: A data loader. batch_size_fn: A function that returns the batch size given a dict-like ``X``. separate_weight_and_bias: Whether to treat weight and bias of a layer as - separate blocks in the block-diagonal GGN. Default: ``True``. + separate blocks in the block-diagonal. Default: ``True``. Returns: - The block-diagonal GGN. + The block-diagonal matrix. """ - # compute the full GGN then zero out the off-diagonal blocks - ggn = GGNLinearOperator(model, loss_func, params, data, batch_size_fn=batch_size_fn) - ggn = from_numpy(ggn @ eye(ggn.shape[1])) + # compute the full matrix then zero out the off-diagonal blocks + linop = linear_operator(model, loss_func, params, data, batch_size_fn=batch_size_fn) + linop = from_numpy(linop @ eye(linop.shape[1])) sizes = [p.numel() for p in params] - # ggn_blocks[i, j] corresponds to the block of (params[i], params[j]) - ggn_blocks = [list(block.split(sizes, dim=1)) for block in ggn.split(sizes, dim=0)] + # matrix_blocks[i, j] corresponds to the block of (params[i], params[j]) + matrix_blocks = [ + list(block.split(sizes, dim=1)) for block in linop.split(sizes, dim=0) + ] # find out which blocks to keep num_params = len(params) @@ -142,10 +146,10 @@ def ggn_block_diagonal( for i, j in product(range(num_params), range(num_params)): if (i, j) not in keep: - ggn_blocks[i][j].zero_() + matrix_blocks[i][j].zero_() # concatenate all blocks - return cat([cat(row_blocks, dim=1) for row_blocks in ggn_blocks], dim=0).numpy() + return cat([cat(row_blocks, dim=1) for row_blocks in matrix_blocks], dim=0).numpy() class WeightShareModel(Sequential):
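
Below the patch, for orientation only: a minimal, hypothetical sketch of how the renamed ``block_diagonal`` helper and the ``correct_eigenvalues`` flag exercised by these tests could be used. The toy model, data, and the ``test.utils`` import path are assumptions for illustration and are not part of the patch.

# Hypothetical usage sketch (not part of the patch). Assumes the repository
# root is on PYTHONPATH so that ``test.utils`` is importable.
from numpy import eye
from torch import manual_seed, rand
from torch.nn import Linear, MSELoss, ReLU, Sequential

from curvlinops import EFLinearOperator, GGNLinearOperator, KFACLinearOperator
from test.utils import block_diagonal

manual_seed(0)
model = Sequential(Linear(4, 3), ReLU(), Linear(3, 2))
loss_func = MSELoss(reduction="mean")
data = [(rand(8, 4), rand(8, 2))]
params = [p for p in model.parameters() if p.requires_grad]

# The first argument now selects the curvature operator (GGN, empirical Fisher, ...),
# instead of the helper being hard-wired to the GGN.
ggn_blocks = block_diagonal(GGNLinearOperator, model, loss_func, params, data)
ef_blocks = block_diagonal(EFLinearOperator, model, loss_func, params, data)

# Eigenvalue-corrected KFAC (EKFAC), toggled via the ``correct_eigenvalues``
# keyword introduced by this patch, densified the same way as in the tests.
ekfac = KFACLinearOperator(model, loss_func, params, data, correct_eigenvalues=True)
ekfac_mat = ekfac @ eye(ekfac.shape[1])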