[REF] Merge hooks into one
f-dangel committed Nov 3, 2023
1 parent d4c9921 commit 0527bdb
Showing 1 changed file with 15 additions and 32 deletions.
47 changes: 15 additions & 32 deletions singd/optim/optimizer.py
@@ -35,8 +35,9 @@ class SINGD(Optimizer):
         Uses the empirical Fisher.
 
     Note:
-        (Implementation concept) The optimizer installs forward and backward hooks on
-        known modules. These hooks compute quantities required for the pre-conditioner.
+        (Implementation concept) The optimizer installs a single forward hook on known
+        modules. During a forward pass, this hook installs a tensor hook on the layer's
+        output which computes the quantities required for the pre-conditioner.
         During `.step`, these quantities will be flushed to update the pre-conditioner,
         compute the approximate natural gradient, and update the network parameters.
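
For readers unfamiliar with the hook pattern described in the note, here is a minimal, self-contained sketch of the idea (illustrative only, not the SINGD code; `accumulate` stands in for `_accumulate_H_terms`): a single forward hook captures the layer's input and, while gradients are enabled, attaches a tensor hook to the output, so the input and the output gradient arrive together during backward.

from functools import partial
from typing import Tuple

import torch
from torch import Tensor, nn


def accumulate(module: nn.Module, inputs: Tuple[Tensor, ...], grad_output: Tensor):
    # Runs during backward with both the captured input and the output gradient,
    # so no separate forward-pre-hook or input buffer is required.
    print(f"{module}: input {inputs[0].shape}, grad_output {grad_output.shape}")


def forward_hook(module: nn.Module, inputs: Tuple[Tensor, ...], output: Tensor):
    # The single forward hook: bind module and input via `partial`, then attach
    # a tensor hook to the layer's output.
    if torch.is_grad_enabled():
        output.register_hook(partial(accumulate, module, inputs))


layer = nn.Linear(4, 3)
handle = layer.register_forward_hook(forward_hook)
layer(torch.rand(8, 4)).sum().backward()
handle.remove()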
@@ -242,9 +243,6 @@ def __init__(
         # NOTE We use the module names (strings) as keys as they don't change when a
         # model is loaded from a checkpoint (unlike the module objects themselves).
 
-        # temporarily stores layer inputs during a forward-backward pass
-        self.inputs: Dict[str, Tensor] = {}
-
         # store matrices for the pre-conditioner
         self.Ks: Dict[str, StructuredMatrix] = {}
         self.Cs: Dict[str, StructuredMatrix] = {}
@@ -455,20 +453,6 @@ def preconditioner_dims(module: Module) -> Tuple[int, int]:
             raise NotImplementedError(f"Initialization not implemented for {module}.")
         return dim_K, dim_C
 
-    def _save_input(self, module: Module, inputs: Tuple[Tensor]):
-        """Internally store input of a layer if triggered by update frequency.
-
-        Saves the input to `self.inputs`.
-
-        Args:
-            module: Layer whose input is stored.
-            inputs: Inputs to the layer.
-        """
-        T = self._get_param_group_entry(module, "T")
-        if is_grad_enabled() and self.steps % T == 0:
-            module_name = self.module_names[module]
-            self.inputs[module_name] = inputs[0].data
-
     def _update_preconditioner(self, module: Module):
         """Update the pre-conditioner matrices and their momenta for a layer.
@@ -580,11 +564,13 @@ def _register_tensor_hook_on_output_to_accumulate_H_terms(
             output: The layer's output tensor.
         """
         T = self._get_param_group_entry(module, "T")
-        if self.steps % T == 0:
-            tensor_hook = partial(self._accumulate_H_terms, module)
+        if is_grad_enabled() and self.steps % T == 0:
+            tensor_hook = partial(self._accumulate_H_terms, module, inputs)
             output.register_hook(tensor_hook)
 
-    def _accumulate_H_terms(self, module: Module, grad_output: Tensor):
+    def _accumulate_H_terms(
+        self, module: Module, inputs: Tuple[Tensor], grad_output: Tensor
+    ):
         """Accumulate the current mini-batch's contribution to `H_K, H_C` for a layer.
 
         Updates the `H_K, H_C` buffers for the module.
@@ -593,14 +579,15 @@ def _accumulate_H_terms(self, module: Module, grad_output: Tensor):
         Args:
             module: Layer whose pre-conditioner is updated.
+            inputs: The layer's input tensors.
             grad_output: The gradient w.r.t. the output.
         """
         loss_average = self._get_param_group_entry(module, "loss_average")
         kfac_approx = self._get_param_group_entry(module, "kfac_approx")
         module_name = self.module_names[module]
 
         # 1) PROCESS INPUTS AND GRAD_OUTPUTS
-        a = self.inputs.pop(module_name)
+        a = inputs[0].data
         # Process into matrix according to kfac_approx
         # For convolutions, unfold the input, for modules with bias terms, append a 1
         a = process_input(a, module, kfac_approx)
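
As a rough, hypothetical illustration of what accumulating `H_K, H_C` amounts to for a linear layer (the real code first processes `a` and the output gradient, respects `loss_average`/`kfac_approx`, and works with structured matrices, none of which is shown here):

import torch

a = torch.randn(8, 5)  # processed layer input, one row per sample (1 appended for bias)
g = torch.randn(8, 3)  # gradient w.r.t. the layer's output, one row per sample
H_K_term = a.T @ a     # input-based Kronecker factor contribution
H_C_term = g.T @ g     # grad-output-based Kronecker factor contribution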
@@ -660,16 +647,12 @@ def _install_hooks(
             if isinstance(mod, self.SUPPORTED_MODULES)
             and any(p.data_ptr() in param_ids for p in mod.parameters())
         }
-        handles = []
-        for module in module_names:
-            handles.extend(
-                (
-                    module.register_forward_pre_hook(self._save_input),
-                    module.register_forward_hook(
-                        self._register_tensor_hook_on_output_to_accumulate_H_terms
-                    ),
-                )
-            )
+        handles = [
+            module.register_forward_hook(
+                self._register_tensor_hook_on_output_to_accumulate_H_terms
+            )
+            for module in module_names
+        ]
         return module_names, handles
 
     def _compute_natural_gradient(self, module: Module) -> Tuple[Tensor, ...]:
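
A stripped-down, hypothetical version of the new hook installation for context (the model, the `nn.Linear` filter, and the stand-in hook are illustrative; the real method filters by `SUPPORTED_MODULES` and the parameters handed to the optimizer):

from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
param_ids = {p.data_ptr() for p in model.parameters()}

# one entry per supported layer that owns parameters trained by the optimizer
module_names = {
    mod: name
    for name, mod in model.named_modules()
    if isinstance(mod, nn.Linear)
    and any(p.data_ptr() in param_ids for p in mod.parameters())
}
# a single forward-hook handle per layer; each can later be detached via .remove()
handles = [
    mod.register_forward_hook(lambda m, inp, out: None)  # stand-in for the real hook
    for mod in module_names
]
for handle in handles:
    handle.remove()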