diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py index cd5482b..4b563a5 100644 --- a/curvlinops/inverse.py +++ b/curvlinops/inverse.py @@ -10,11 +10,7 @@ from torch import Tensor, cat, cholesky_inverse, eye, float64, outer from torch.linalg import cholesky, eigh -from curvlinops.kfac import KFACLinearOperator, ParameterMatrixType - -KFACInvType = TypeVar( - "KFACInvType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]] -) +from curvlinops.kfac import KFACLinearOperator, KFACType, ParameterMatrixType class _InverseLinearOperator(LinearOperator): @@ -355,6 +351,8 @@ def __init__( raise ValueError( "Heuristic and exact damping require a single damping value." ) + if self._A._correct_eigenvalues and not use_exact_damping: + raise ValueError("Only exact damping is supported for EKFAC.") self._damping = damping self._use_heuristic_damping = use_heuristic_damping @@ -362,8 +360,8 @@ def __init__( self._use_exact_damping = use_exact_damping self._cache = cache self._retry_double_precision = retry_double_precision - self._inverse_input_covariances: Dict[str, KFACInvType] = {} - self._inverse_gradient_covariances: Dict[str, KFACInvType] = {} + self._inverse_input_covariances: Dict[str, KFACType] = {} + self._inverse_gradient_covariances: Dict[str, KFACType] = {} def _compute_damping( self, aaT: Optional[Tensor], ggT: Optional[Tensor] @@ -408,18 +406,20 @@ def _damped_cholesky(self, M: Tensor, damping: float) -> Tensor: ) def _compute_inverse_factors( - self, aaT: Optional[Tensor], ggT: Optional[Tensor] - ) -> Tuple[KFACInvType, KFACInvType]: + self, aaT: Optional[Tensor], ggT: Optional[Tensor], name: str + ) -> Tuple[KFACType, KFACType, Optional[Tensor]]: """Compute the inverses of the Kronecker factors for a given layer. Args: aaT: Input covariance matrix. ``None`` for biases. ggT: Gradient covariance matrix. + name: Name of the layer for which to invert Kronecker factors. Returns: Tuple of inverses (or eigendecompositions) of the input and gradient - covariance Kronecker factors. Can be ``None`` if the input or gradient - covariance is ``None`` (e.g. the input covariances for biases). + covariance Kronecker factors and optionally eigenvalues. Can be ``None`` if + the input or gradient covariance is ``None`` (e.g. the input covariances for + biases). Raises: RuntimeError: If a Cholesky decomposition (and optionally the retry in @@ -430,7 +430,27 @@ def _compute_inverse_factors( # Kronecker-factored eigenbasis (KFE). aaT_eigvals, aaT_eigvecs = (None, None) if aaT is None else eigh(aaT) ggT_eigvals, ggT_eigvecs = (None, None) if ggT is None else eigh(ggT) - return (aaT_eigvecs, aaT_eigvals), (ggT_eigvecs, ggT_eigvals) + param_pos = self._A._mapping[name] + if ( + not self._A._separate_weight_and_bias + and "weight" in param_pos + and "bias" in param_pos + ): + inv_damped_eigenvalues = ( + outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) + ) + else: + inv_damped_eigenvalues = {} + for p_name, pos in param_pos.items(): + if p_name == "weight": + inv_damped_eigenvalues[pos] = ( + outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) + ) + else: + inv_damped_eigenvalues[pos] = ggT_eigvals.add( + self._damping + ).pow_(-1) + return aaT_eigvecs, ggT_eigvecs, inv_damped_eigenvalues else: damping_aaT, damping_ggT = self._compute_damping(aaT, ggT) @@ -476,11 +496,11 @@ def _compute_inverse_factors( raise error ggT_inv = None if ggT_chol is None else cholesky_inverse(ggT_chol) - return aaT_inv, ggT_inv + return aaT_inv, ggT_inv, None def _compute_or_get_cached_inverse( self, name: str - ) -> Tuple[KFACInvType, KFACInvType]: + ) -> Tuple[KFACType, KFACType, Optional[Tensor]]: """Invert the Kronecker factors of the KFACLinearOperator or retrieve them. Args: @@ -488,117 +508,37 @@ def _compute_or_get_cached_inverse( Returns: Tuple of inverses (or eigendecompositions) of the input and gradient - covariance Kronecker factors. Can be ``None`` if the input or gradient - covariance is ``None`` (e.g. the input covariances for biases). + covariance Kronecker factors and optionally eigenvalues. Can be ``None`` if + the input or gradient covariance is ``None`` (e.g. the input covariances for + biases). """ if name in self._inverse_input_covariances: aaT_inv = self._inverse_input_covariances.get(name) ggT_inv = self._inverse_gradient_covariances.get(name) - return aaT_inv, ggT_inv - - aaT = self._A._input_covariances.get(name) - ggT = self._A._gradient_covariances.get(name) - aaT_inv, ggT_inv = self._compute_inverse_factors(aaT, ggT) - - if self._cache: - self._inverse_input_covariances[name] = aaT_inv - self._inverse_gradient_covariances[name] = ggT_inv - - return aaT_inv, ggT_inv - - def _left_and_right_multiply( - self, M_joint: Tensor, aaT_inv: KFACInvType, ggT_inv: KFACInvType - ) -> Tensor: - """Left and right multiply matrix with inverse Kronecker factors. - - Args: - M_joint: Matrix for multiplication. - aaT_inv: Inverse of the input covariance Kronecker factor. ``None`` for - biases. - ggT_inv: Inverse of the gradient covariance Kronecker factor. - - Returns: - Matrix-multiplication result ``KFAC⁻¹ @ M_joint``. - """ - if self._use_exact_damping: - # Perform damped preconditioning in KFE, e.g. see equation (21) in - # https://arxiv.org/abs/2308.03296. - aaT_eigvecs, aaT_eigvals = aaT_inv - ggT_eigvecs, ggT_eigvals = ggT_inv - # Transform in eigenbasis. - M_joint = einsum( - ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m i k, k l -> m j l" - ) - # Divide by damped eigenvalues to perform the inversion. - M_joint.div_(outer(ggT_eigvals, aaT_eigvals).add_(self._damping)) - # Transform back to standard basis. - M_joint = einsum( - ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m j k, l k -> m i l" - ) + return aaT_inv, ggT_inv, None + + if self._A._correct_eigenvalues: + aaT_inv = self._A._input_covariances_eigenvectors.get(name) + ggT_inv = self._A._gradient_covariances_eigenvectors.get(name) + eigenvalues = self._A._corrected_eigenvalues.get(name) + if isinstance(eigenvalues, dict): + inv_damped_eigenvalues = {} + for key, val in eigenvalues.items(): + inv_damped_eigenvalues[key] = val.add(self._damping).pow_(-1) + elif isinstance(eigenvalues, Tensor): + inv_damped_eigenvalues = eigenvalues.add(self._damping).pow_(-1) else: - M_joint = einsum(ggT_inv, M_joint, aaT_inv, "i j, m j k, k l -> m i l") - return M_joint - - def _separate_left_and_right_multiply( - self, - M_torch: Tensor, - param_pos: Dict[str, int], - aaT_inv: KFACInvType, - ggT_inv: KFACInvType, - ) -> Tensor: - """Multiply matrix with inverse Kronecker factors for separated weight and bias. - - Args: - M_torch: Matrix for multiplication. - param_pos: Dictionary with positions of the weight and bias parameters. - aaT_inv: Inverse of the input covariance Kronecker factor. ``None`` for - biases. - ggT_inv: Inverse of the gradient covariance Kronecker factor. - - Returns: - Matrix-multiplication result ``KFAC⁻¹ @ M_torch``. - """ - if self._use_exact_damping: - # Perform damped preconditioning in KFE, e.g. see equation (21) in - # https://arxiv.org/abs/2308.03296. - aaT_eigvecs, aaT_eigvals = aaT_inv - ggT_eigvecs, ggT_eigvals = ggT_inv - - for p_name, pos in param_pos.items(): - # for weights we need to multiply from the right with aaT - # for weights and biases we need to multiply from the left with ggT - if p_name == "weight": - M_w = rearrange(M_torch[pos], "m c_out ... -> m c_out (...)") - aaT_fac = aaT_eigvecs if self._use_exact_damping else aaT_inv - # If `use_exact_damping` is `True`, we transform to eigenbasis - M_torch[pos] = einsum(M_w, aaT_fac, "m i j, j k -> m i k") - - ggT_fac = ggT_eigvecs if self._use_exact_damping else ggT_inv - dims = ( - "m i ... -> m j ..." - if self._use_exact_damping - else " m j ... -> m i ..." + aaT = self._A._input_covariances.get(name) + ggT = self._A._gradient_covariances.get(name) + aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors( + aaT, ggT, name ) - # If `use_exact_damping` is `True`, we transform to eigenbasis - M_torch[pos] = einsum(ggT_fac, M_torch[pos], f"i j, {dims}") - - if self._use_exact_damping: - # Divide by damped eigenvalues to perform the inversion and transform - # back to standard basis. - if p_name == "weight": - M_torch[pos].div_( - outer(ggT_eigvals, aaT_eigvals).add_(self._damping) - ) - M_torch[pos] = einsum( - M_torch[pos], aaT_eigvecs, "m i j, k j -> m i k" - ) - else: - M_torch[pos].div_(ggT_eigvals.add_(self._damping)) - M_torch[pos] = einsum( - ggT_eigvecs, M_torch[pos], "i j, m j ... -> m i ..." - ) - return M_torch + if self._cache: + self._inverse_input_covariances[name] = aaT_inv + self._inverse_gradient_covariances[name] = ggT_inv + + return aaT_inv, ggT_inv, inv_damped_eigenvalues def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: """Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch. @@ -621,12 +561,19 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: ``[D, K]`` with some ``K``. """ return_tensor, M_torch = self._A._check_input_type_and_preprocess(M_torch) - if not self._A._input_covariances and not self._A._gradient_covariances: + if ( + not self._A._input_covariances + and not self._A._gradient_covariances + and not self._A._input_covariances_eigenvectors + and not self._A._gradient_covariances_eigenvectors + ): self._A._compute_kfac() for mod_name, param_pos in self._A._mapping.items(): # retrieve the inverses of the Kronecker factors from cache or invert them - aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name) + aaT_inv, ggT_inv, inv_damped_eigenvalues = ( + self._compute_or_get_cached_inverse(mod_name) + ) # cache the weight shape to ensure correct shapes are returned if "weight" in param_pos: weight_shape = M_torch[param_pos["weight"]].shape @@ -640,12 +587,14 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: w_pos, b_pos = param_pos["weight"], param_pos["bias"] M_w = rearrange(M_torch[w_pos], "m c_out ... -> m c_out (...)") M_joint = cat([M_w, M_torch[b_pos].unsqueeze(2)], dim=2) - M_joint = self._left_and_right_multiply(M_joint, aaT_inv, ggT_inv) + M_joint = self._A._left_and_right_multiply( + M_joint, aaT_inv, ggT_inv, inv_damped_eigenvalues + ) w_cols = M_w.shape[2] M_torch[w_pos], M_torch[b_pos] = M_joint.split([w_cols, 1], dim=2) else: - M_torch = self._separate_left_and_right_multiply( - M_torch, param_pos, aaT_inv, ggT_inv + M_torch = self._A._separate_left_and_right_multiply( + M_torch, param_pos, aaT_inv, ggT_inv, inv_damped_eigenvalues ) # restore original shapes