Skip to content

Commit

Permalink
Move un-scaling of H_C into the update of m_K and m_C for better numerical stability
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Nov 4, 2023
1 parent 298ea86 commit f050778
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions singd/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,20 @@ def _update_preconditioner(self, module: Module):
H_K: StructuredMatrix = self.H_Ks.pop(module_name)
H_C: StructuredMatrix = self.H_Cs.pop(module_name)

# un-scale `H_C = structure(C.T @ (grad_scale * g) @ (grad_scale * g).T @ C)`
# Define `grad_unscaling` to later un-scale
# `H_C = structure(C.T @ (grad_scale * g) @ (grad_scale * g).T @ C)`.
prev_grad_scale = self._get_grad_scale(self.steps - 1)
grad_scale = self._get_grad_scale(self.steps)
if grad_scale != 1.0 or prev_grad_scale != 1.0:
# In total we have to divide by `grad_scale ** 2`. The `H_C` computed
# in the backward pass was already divided by `prev_grad_scale` to avoid
# overflows. Here, we apply the remaining un-scaling
H_C.mul_(prev_grad_scale / grad_scale**2)
# In total we have to divide by `grad_scale ** 2`. The `H_C` computed
# in the backward pass was already divided by `prev_grad_scale` to avoid
# overflows. So we apply the remaining un-scaling. For increased
# numerical stability we do not scale `H_C` directly but instead
# include the un-scaling in the update of `m_K` and `m_C`.
grad_unscaling = (
prev_grad_scale / grad_scale**2
if grad_scale != 1.0 or prev_grad_scale != 1.0
else 1.0
)

# 1) COMPUTE UPDATE
K_tK = K.from_inner()
Expand All @@ -517,13 +523,18 @@ def _update_preconditioner(self, module: Module):

# step for m_K
new_m_K = K.zeros(dim_K, dtype=dtype_K, device=dev)
new_m_K.add_(H_K, alpha=1.0 if kfac_like else H_C.average_trace())
new_m_K.add_(
H_K, alpha=1.0 if kfac_like else grad_unscaling * H_C.average_trace()
)
new_m_K.add_(K_tK, alpha=damping * (1.0 if kfac_like else C_tC.average_trace()))
new_m_K.diag_add_(-1.0).mul_(scale)

# step for m_C
new_m_C = C.zeros(dim_C, dtype=dtype_C, device=dev)
new_m_C.add_(H_C, alpha=1.0 if kfac_like else H_K.average_trace())
new_m_C.add_(
H_C,
alpha=grad_unscaling if kfac_like else grad_unscaling * H_K.average_trace(),
)
new_m_C.add_(C_tC, alpha=damping * (1.0 if kfac_like else K_tK.average_trace()))
new_m_C.diag_add_(-1.0).mul_(scale)

Expand Down

0 comments on commit f050778

Please sign in to comment.