diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py index 6473a8e..1ccacb4 100644 --- a/curvlinops/inverse.py +++ b/curvlinops/inverse.py @@ -430,14 +430,12 @@ def _compute_inv_damped_eigenvalues( else: inv_damped_eigenvalues = {} for p_name, pos in param_pos.items(): - if p_name == "weight": - inv_damped_eigenvalues[pos] = ( - outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) - ) - else: - inv_damped_eigenvalues[pos] = ggT_eigvals.add(self._damping).pow_( - -1 - ) + inv_damped_eigenvalues[pos] = ( + outer(ggT_eigvals, aaT_eigvals) + if p_name == "weight" + else ggT_eigvals + ) + inv_damped_eigenvalues[pos].add_(self._damping).pow_(-1) return inv_damped_eigenvalues def _compute_inverse_factors( @@ -536,8 +534,8 @@ def _compute_or_get_cached_inverse( return aaT_inv, ggT_inv, None if self._A._correct_eigenvalues: - aaT_inv = self._A._input_covariances_eigenvectors.get(name) - ggT_inv = self._A._gradient_covariances_eigenvectors.get(name) + aaT_eigenvecs = self._A._input_covariances_eigenvectors.get(name) + ggT_eigenvecs = self._A._gradient_covariances_eigenvectors.get(name) eigenvalues = self._A._corrected_eigenvalues.get(name) if isinstance(eigenvalues, dict): inv_damped_eigenvalues = {} @@ -545,16 +543,17 @@ def _compute_or_get_cached_inverse( inv_damped_eigenvalues[key] = val.add(self._damping).pow_(-1) elif isinstance(eigenvalues, Tensor): inv_damped_eigenvalues = eigenvalues.add(self._damping).pow_(-1) - else: - aaT = self._A._input_covariances.get(name) - ggT = self._A._gradient_covariances.get(name) - aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors( - aaT, ggT, name - ) + return aaT_eigenvecs, ggT_eigenvecs, inv_damped_eigenvalues + + aaT = self._A._input_covariances.get(name) + ggT = self._A._gradient_covariances.get(name) + aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors( + aaT, ggT, name + ) - if self._cache: - self._inverse_input_covariances[name] = aaT_inv - self._inverse_gradient_covariances[name] = ggT_inv + if self._cache: + self._inverse_input_covariances[name] = aaT_inv + self._inverse_gradient_covariances[name] = ggT_inv return aaT_inv, ggT_inv, inv_damped_eigenvalues