Skip to content

Commit

Permalink
Refactor inverse
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Sep 21, 2024
1 parent 31cab8a commit 89c814f
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,14 +430,12 @@ def _compute_inv_damped_eigenvalues(
else:
inv_damped_eigenvalues = {}
for p_name, pos in param_pos.items():
if p_name == "weight":
inv_damped_eigenvalues[pos] = (
outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1)
)
else:
inv_damped_eigenvalues[pos] = ggT_eigvals.add(self._damping).pow_(
-1
)
inv_damped_eigenvalues[pos] = (
outer(ggT_eigvals, aaT_eigvals)
if p_name == "weight"
else ggT_eigvals
)
inv_damped_eigenvalues[pos].add_(self._damping).pow_(-1)
return inv_damped_eigenvalues

def _compute_inverse_factors(
Expand Down Expand Up @@ -536,25 +534,26 @@ def _compute_or_get_cached_inverse(
return aaT_inv, ggT_inv, None

if self._A._correct_eigenvalues:
aaT_inv = self._A._input_covariances_eigenvectors.get(name)
ggT_inv = self._A._gradient_covariances_eigenvectors.get(name)
aaT_eigenvecs = self._A._input_covariances_eigenvectors.get(name)
ggT_eigenvecs = self._A._gradient_covariances_eigenvectors.get(name)
eigenvalues = self._A._corrected_eigenvalues.get(name)
if isinstance(eigenvalues, dict):
inv_damped_eigenvalues = {}
for key, val in eigenvalues.items():
inv_damped_eigenvalues[key] = val.add(self._damping).pow_(-1)
elif isinstance(eigenvalues, Tensor):
inv_damped_eigenvalues = eigenvalues.add(self._damping).pow_(-1)
else:
aaT = self._A._input_covariances.get(name)
ggT = self._A._gradient_covariances.get(name)
aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors(
aaT, ggT, name
)
return aaT_eigenvecs, ggT_eigenvecs, inv_damped_eigenvalues

aaT = self._A._input_covariances.get(name)
ggT = self._A._gradient_covariances.get(name)
aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors(
aaT, ggT, name
)

if self._cache:
self._inverse_input_covariances[name] = aaT_inv
self._inverse_gradient_covariances[name] = ggT_inv
if self._cache:
self._inverse_input_covariances[name] = aaT_inv
self._inverse_gradient_covariances[name] = ggT_inv

return aaT_inv, ggT_inv, inv_damped_eigenvalues

Expand Down

0 comments on commit 89c814f

Please sign in to comment.