[REF] Merge hooks into one (#65)
* [ADD] Test to reproduce #56

* [REF] Replace module full backward hook with tensor hook on output

See pytorch/pytorch#61519 (comment) for details

* [REF] Merge hooks into one

* [DEL] Remove comparison of inputs

* [DEL] Remove `inputs` from state attributes
f-dangel authored Nov 8, 2023
1 parent 5857af8 commit d2c0d2f
Showing 2 changed files with 15 additions and 42 deletions.
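The first change described in the commit message swaps PyTorch's module full backward hook for a tensor hook registered on the layer's output; the linked PyTorch issue discusses the pitfalls of the full-backward-hook route. As a rough, standalone illustration of the two registration styles (a toy example, not code from this commit; `layer`, the hook names, and the shapes are invented):

```python
# Toy contrast of the two hook styles the commit message refers to (illustrative
# only; names and shapes are invented, not taken from singd).
import torch
from torch import nn

layer = nn.Linear(4, 3)
x = torch.randn(8, 4, requires_grad=True)


# Old approach: a module full backward hook receives grad_input and grad_output.
def full_backward_hook(module, grad_input, grad_output):
    print("full backward hook, grad_output shape:", grad_output[0].shape)


old_handle = layer.register_full_backward_hook(full_backward_hook)


# New approach: a forward hook attaches a tensor hook to the layer's output; the
# tensor hook receives only the gradient w.r.t. that output.
def forward_hook(module, inputs, output):
    output.register_hook(lambda grad: print("tensor hook, grad shape:", grad.shape))


new_handle = layer.register_forward_hook(forward_hook)

layer(x).sum().backward()
old_handle.remove()
new_handle.remove()
```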
singd/optim/optimizer.py (50 changes: 15 additions & 35 deletions)
@@ -35,8 +35,9 @@ class SINGD(Optimizer):
Uses the empirical Fisher.
Note:
(Implementation concept) The optimizer installs forward and backward hooks on
known modules. These hooks compute quantities required for the pre-conditioner.
(Implementation concept) The optimizer installs a single forward hook on known
modules. During a forward pass, this hook installs a tensor hook on the layer's
output which computes the quantities required for the pre-conditioner.
During `.step`, these quantities will be flushed to update the pre-conditioner,
compute the approximate natural gradient, and update the network parameters.
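The reworked docstring above summarizes the new control flow: one forward hook per layer, which in turn installs a tensor hook on that layer's output. A minimal sketch of this pattern, with hypothetical names and plain Gram matrices standing in for the structured `H_K`/`H_C` terms the optimizer actually accumulates:

```python
# Minimal sketch of the "single forward hook installs a tensor hook" pattern the
# docstring describes. Names and accumulated quantities are placeholders; the real
# optimizer updates structured H_K/H_C buffers instead of dense Gram matrices.
from functools import partial

import torch
from torch import nn

accumulated = {}  # hypothetical buffer, keyed by module name


def accumulate(name, inputs, grad_output):
    """Tensor hook: sees the layer inputs (bound via partial) and grad_output."""
    a = inputs[0].detach()
    g = grad_output.detach()
    accumulated[name] = (a.T @ a, g.T @ g)  # stand-ins for the H_K / H_C terms


def forward_hook(name, module, inputs, output):
    """Forward hook: installs the tensor hook on the layer's output."""
    if torch.is_grad_enabled():
        output.register_hook(partial(accumulate, name, inputs))


layer = nn.Linear(4, 3)
layer.register_forward_hook(partial(forward_hook, "0"))
layer(torch.randn(8, 4)).sum().backward()
print({name: tuple(t.shape for t in terms) for name, terms in accumulated.items()})
```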
@@ -242,9 +243,6 @@ def __init__(
# NOTE We use the module names (strings) as keys as they don't change when a
# model is loaded from a checkpoint (unlike the module objects themselves).

# temporarily stores layer inputs during a forward-backward pass
self.inputs: Dict[str, Tensor] = {}

# store matrices for the pre-conditioner
self.Ks: Dict[str, StructuredMatrix] = {}
self.Cs: Dict[str, StructuredMatrix] = {}
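The retained comment above motivates keying optimizer bookkeeping by module names rather than module objects. A small, hypothetical illustration of why string keys are the stable choice across checkpoint round-trips:

```python
# Hypothetical illustration: module *objects* differ between two instantiations of
# the same architecture, but the names from `named_modules()` are stable, so
# string-keyed state can be matched up after loading a checkpoint.
from torch import nn


def build_model():
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


model_a, model_b = build_model(), build_model()
linear_names_a = [n for n, m in model_a.named_modules() if isinstance(m, nn.Linear)]
linear_names_b = [n for n, m in model_b.named_modules() if isinstance(m, nn.Linear)]

assert linear_names_a == linear_names_b == ["0", "2"]  # names are stable ...
assert model_a[0] is not model_b[0]  # ... while the module objects are not
```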
@@ -455,20 +453,6 @@ def preconditioner_dims(module: Module) -> Tuple[int, int]:
raise NotImplementedError(f"Initialization not implemented for {module}.")
return dim_K, dim_C

def _save_input(self, module: Module, inputs: Tuple[Tensor]):
"""Internally store input of a layer if triggered by update frequency.
Saves the input to `self.inputs`.
Args:
module: Layer whose input is stored.
inputs: Inputs to the layer.
"""
T = self._get_param_group_entry(module, "T")
if is_grad_enabled() and self.steps % T == 0:
module_name = self.module_names[module]
self.inputs[module_name] = inputs[0].data

def _update_preconditioner(self, module: Module):
"""Update the pre-conditioner matrices and their momenta for a layer.
@@ -580,27 +564,28 @@ def _register_tensor_hook_on_output_to_accumulate_H_terms(
output: The layer's output tensor.
"""
T = self._get_param_group_entry(module, "T")
if self.steps % T == 0:
tensor_hook = partial(self._accumulate_H_terms, module)
if is_grad_enabled() and self.steps % T == 0:
tensor_hook = partial(self._accumulate_H_terms, module, inputs)
output.register_hook(tensor_hook)

def _accumulate_H_terms(self, module: Module, grad_output: Tensor):
def _accumulate_H_terms(
self, module: Module, inputs: Tuple[Tensor], grad_output: Tensor
):
"""Accumulate the current mini-batch's contribution to `H_K, H_C` for a layer.
Updates the `H_K, H_C` buffers for the module.
Requires that the layer inputs have been stored in `self.inputs`.
Args:
module: Layer whose pre-conditioner is updated.
inputs: The layer's input tensors.
grad_output: The gradient w.r.t. the output.
"""
loss_average = self._get_param_group_entry(module, "loss_average")
kfac_approx = self._get_param_group_entry(module, "kfac_approx")
module_name = self.module_names[module]

# 1) PROCESS INPUTS AND GRAD_OUTPUTS
a = self.inputs.pop(module_name)
a = inputs[0].data
# Process into matrix according to kfac_approx
# For convolutions, unfold the input, for modules with bias terms, append a 1
a = process_input(a, module, kfac_approx)
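Two details of the new hook are worth noting: the layer inputs are now bound into the tensor hook via `functools.partial` instead of being stashed in `self.inputs`, and registration is guarded by `is_grad_enabled()` in addition to the update-frequency check. One reason the guard matters, shown as a hedged toy example (not from the commit): under `torch.no_grad()` the output requires no gradient, so a tensor hook cannot be attached to it.

```python
# Hypothetical demonstration of why hook registration is skipped when grad mode is
# disabled: a tensor hook cannot be attached to an output that requires no grad.
import torch
from torch import nn

layer = nn.Linear(4, 3)

with torch.no_grad():  # e.g. an evaluation pass
    output = layer(torch.randn(2, 4))
    try:
        output.register_hook(lambda grad: grad)
    except RuntimeError as err:
        print("skipped hook registration:", err)
```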
@@ -660,16 +645,12 @@ def _install_hooks(
if isinstance(mod, self.SUPPORTED_MODULES)
and any(p.data_ptr() in param_ids for p in mod.parameters())
}
handles = []
for module in module_names:
handles.extend(
(
module.register_forward_pre_hook(self._save_input),
module.register_forward_hook(
self._register_tensor_hook_on_output_to_accumulate_H_terms
),
)
handles = [
module.register_forward_hook(
self._register_tensor_hook_on_output_to_accumulate_H_terms
)
for module in module_names
]
return module_names, handles

def _compute_natural_gradient(self, module: Module) -> Tuple[Tensor, ...]:
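The simplified `_install_hooks` now registers exactly one forward hook, and hence collects exactly one handle, per supported module. A brief, hypothetical sketch of that handle bookkeeping (the placeholder hook body and model are made up; handles allow the hooks to be detached later):

```python
# Hypothetical sketch of one-handle-per-module registration and later removal.
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
supported = [mod for mod in model.modules() if isinstance(mod, nn.Linear)]


def forward_hook(module, inputs, output):
    pass  # placeholder for the real hook body


handles = [mod.register_forward_hook(forward_hook) for mod in supported]

# ... later, e.g. when the hooks are no longer needed:
for handle in handles:
    handle.remove()
```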
@@ -865,7 +846,6 @@ def set_current_grad_scale(self, grad_scale: float):
"m_Cs",
"H_Ks",
"H_Cs",
"inputs",
]

def state_dict(self) -> Dict[str, Any]:
test/utils.py (7 changes: 0 additions & 7 deletions)
@@ -115,16 +115,9 @@ def compare_optimizers( # noqa: C901
rtol_hook = rtol_hook if rtol_hook is not None else rtol

if check_hook_quantities:
assert set(optim1.inputs.keys()) == set(optim2.inputs.keys())
assert set(optim1.H_Ks.keys()) == set(optim2.H_Ks.keys())
assert set(optim1.H_Cs.keys()) == set(optim2.H_Cs.keys())

for name in optim1.inputs:
inputs1, inputs2 = optim1.inputs[name], optim2.inputs[name]
report_nonclose(
inputs1, inputs2, atol=atol_hook, rtol=rtol_hook, name="inputs"
)

for name in optim1.H_Ks:
H_K1 = optim1.H_Ks[name].value.to_dense()
H_K2 = optim2.H_Ks[name].value.to_dense()