From 0527bdb8dc2aafba0ce7e7c7499c7009f1e40d0e Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Fri, 3 Nov 2023 11:10:12 -0400
Subject: [PATCH] [REF] Merge hooks into one

---
 singd/optim/optimizer.py | 47 +++++++++++++---------------------
 1 file changed, 15 insertions(+), 32 deletions(-)

diff --git a/singd/optim/optimizer.py b/singd/optim/optimizer.py
index bddd5e3..c23422a 100644
--- a/singd/optim/optimizer.py
+++ b/singd/optim/optimizer.py
@@ -35,8 +35,9 @@ class SINGD(Optimizer):
         Uses the empirical Fisher.
 
     Note:
-        (Implementation concept) The optimizer installs forward and backward hooks on
-        known modules. These hooks compute quantities required for the pre-conditioner.
+        (Implementation concept) The optimizer installs a single forward hook on known
+        modules. During a forward pass, this hook installs a tensor hook on the layer's
+        output which computes the quantities required for the pre-conditioner.
         During `.step`, these quantities will be flushed to update the pre-conditioner,
         compute the approximate natural gradient, and update the network parameters.
 
@@ -242,9 +243,6 @@ def __init__(
         # NOTE We use the module names (strings) as keys as they don't change when a
         # model is loaded from a checkpoint (unlike the module objects themselves).
 
-        # temporarily stores layer inputs during a forward-backward pass
-        self.inputs: Dict[str, Tensor] = {}
-
         # store matrices for the pre-conditioner
         self.Ks: Dict[str, StructuredMatrix] = {}
         self.Cs: Dict[str, StructuredMatrix] = {}
@@ -455,20 +453,6 @@ def preconditioner_dims(module: Module) -> Tuple[int, int]:
             raise NotImplementedError(f"Initialization not implemented for {module}.")
         return dim_K, dim_C
 
-    def _save_input(self, module: Module, inputs: Tuple[Tensor]):
-        """Internally store input of a layer if triggered by update frequency.
-
-        Saves the input to `self.inputs`.
-
-        Args:
-            module: Layer whose input is stored.
-            inputs: Inputs to the layer.
-        """
-        T = self._get_param_group_entry(module, "T")
-        if is_grad_enabled() and self.steps % T == 0:
-            module_name = self.module_names[module]
-            self.inputs[module_name] = inputs[0].data
-
     def _update_preconditioner(self, module: Module):
         """Update the pre-conditioner matrices and their momenta for a layer.
 
@@ -580,11 +564,13 @@ def _register_tensor_hook_on_output_to_accumulate_H_terms(
             output: The layer's output tensor.
         """
         T = self._get_param_group_entry(module, "T")
-        if self.steps % T == 0:
-            tensor_hook = partial(self._accumulate_H_terms, module)
+        if is_grad_enabled() and self.steps % T == 0:
+            tensor_hook = partial(self._accumulate_H_terms, module, inputs)
             output.register_hook(tensor_hook)
 
-    def _accumulate_H_terms(self, module: Module, grad_output: Tensor):
+    def _accumulate_H_terms(
+        self, module: Module, inputs: Tuple[Tensor], grad_output: Tensor
+    ):
         """Accumulate the current mini-batch's contribution to `H_K, H_C` for a layer.
 
         Updates the `H_K, H_C` buffers for the module.
@@ -593,6 +579,7 @@ def _accumulate_H_terms(self, module: Module, grad_output: Tensor):
 
         Args:
             module: Layer whose pre-conditioner is updated.
+            inputs: The layer's input tensors.
             grad_output: The gradient w.r.t. the output.
         """
         loss_average = self._get_param_group_entry(module, "loss_average")
@@ -600,7 +587,7 @@ def _accumulate_H_terms(self, module: Module, grad_output: Tensor):
         module_name = self.module_names[module]
 
         # 1) PROCESS INPUTS AND GRAD_OUTPUTS
-        a = self.inputs.pop(module_name)
+        a = inputs[0].data
        # Process into matrix according to kfac_approx
         # For convolutions, unfold the input, for modules with bias terms, append a 1
         a = process_input(a, module, kfac_approx)
@@ -660,16 +647,12 @@ def _install_hooks(
             if isinstance(mod, self.SUPPORTED_MODULES)
             and any(p.data_ptr() in param_ids for p in mod.parameters())
         }
-        handles = []
-        for module in module_names:
-            handles.extend(
-                (
-                    module.register_forward_pre_hook(self._save_input),
-                    module.register_forward_hook(
-                        self._register_tensor_hook_on_output_to_accumulate_H_terms
-                    ),
-                )
+        handles = [
+            module.register_forward_hook(
+                self._register_tensor_hook_on_output_to_accumulate_H_terms
             )
+            for module in module_names
+        ]
        return module_names, handles

     def _compute_natural_gradient(self, module: Module) -> Tuple[Tensor, ...]: