[REF] Merge hooks into one #65

Merged · 6 commits · Nov 8, 2023
50 changes: 15 additions & 35 deletions singd/optim/optimizer.py
@@ -35,8 +35,9 @@ class SINGD(Optimizer):
Uses the empirical Fisher.

Note:
(Implementation concept) The optimizer installs forward and backward hooks on
known modules. These hooks compute quantities required for the pre-conditioner.
(Implementation concept) The optimizer installs a single forward hook on known
modules. During a forward pass, this hook installs a tensor hook on the layer's
output which computes the quantities required for the pre-conditioner.
During `.step`, these quantities will be flushed to update the pre-conditioner,
compute the approximate natural gradient, and update the network parameters.

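For readers less familiar with PyTorch's hook API, here is a minimal sketch of the concept in the updated note (illustrative only, not the SINGD implementation): a single forward hook installs a tensor hook on the layer's output, the tensor hook stashes the required quantities in a buffer, and a later `.step` would flush that buffer.

```python
import torch
from torch import nn

# Minimal sketch of the single-forward-hook concept (not the SINGD code).
buffer = {}

def forward_hook(module, inputs, output):
    def tensor_hook(grad_output):
        # Stash the layer input and the gradient w.r.t. the output; these are
        # the raw ingredients for the pre-conditioner update.
        buffer[module] = (inputs[0].detach(), grad_output.detach())

    output.register_hook(tensor_hook)

layer = nn.Linear(4, 3)
layer.register_forward_hook(forward_hook)
layer(torch.randn(2, 4)).sum().backward()

a, g = buffer.pop(layer)  # what a `.step`-like flush would consume
print(a.shape, g.shape)   # torch.Size([2, 4]) torch.Size([2, 3])
```
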
@@ -242,9 +243,6 @@ def __init__(
# NOTE We use the module names (strings) as keys as they don't change when a
# model is loaded from a checkpoint (unlike the module objects themselves).

# temporarily stores layer inputs during a forward-backward pass
self.inputs: Dict[str, Tensor] = {}

# store matrices for the pre-conditioner
self.Ks: Dict[str, StructuredMatrix] = {}
self.Cs: Dict[str, StructuredMatrix] = {}
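
As a side note on the bookkeeping above, here is a toy illustration (not SINGD's actual code) of why the pre-conditioner dictionaries are keyed by module names rather than module objects: the string names are stable across checkpoint round-trips, whereas the module objects themselves are re-created.

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

# Map module -> name once; the string keys remain valid even after the model
# is re-instantiated and its weights loaded from a checkpoint.
module_names = {
    mod: name for name, mod in model.named_modules() if isinstance(mod, nn.Linear)
}
print(sorted(module_names.values()))  # ['0', '2']
```
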
@@ -455,20 +453,6 @@ def preconditioner_dims(module: Module) -> Tuple[int, int]:
raise NotImplementedError(f"Initialization not implemented for {module}.")
return dim_K, dim_C

def _save_input(self, module: Module, inputs: Tuple[Tensor]):
"""Internally store input of a layer if triggered by update frequency.

Saves the input to `self.inputs`.

Args:
module: Layer whose input is stored.
inputs: Inputs to the layer.
"""
T = self._get_param_group_entry(module, "T")
if is_grad_enabled() and self.steps % T == 0:
module_name = self.module_names[module]
self.inputs[module_name] = inputs[0].data

def _update_preconditioner(self, module: Module):
"""Update the pre-conditioner matrices and their momenta for a layer.

@@ -580,27 +564,28 @@ def _register_tensor_hook_on_output_to_accumulate_H_terms(
output: The layer's output tensor.
"""
T = self._get_param_group_entry(module, "T")
if self.steps % T == 0:
tensor_hook = partial(self._accumulate_H_terms, module)
if is_grad_enabled() and self.steps % T == 0:
tensor_hook = partial(self._accumulate_H_terms, module, inputs)
output.register_hook(tensor_hook)

def _accumulate_H_terms(self, module: Module, grad_output: Tensor):
def _accumulate_H_terms(
self, module: Module, inputs: Tuple[Tensor], grad_output: Tensor
):
"""Accumulate the current mini-batch's contribution to `H_K, H_C` for a layer.

Updates the `H_K, H_C` buffers for the module.

Requires that the layer inputs have been stored in `self.inputs`.

Args:
module: Layer whose pre-conditioner is updated.
inputs: The layer's input tensors.
grad_output: The gradient w.r.t. the output.
"""
loss_average = self._get_param_group_entry(module, "loss_average")
kfac_approx = self._get_param_group_entry(module, "kfac_approx")
module_name = self.module_names[module]

# 1) PROCESS INPUTS AND GRAD_OUTPUTS
a = self.inputs.pop(module_name)
a = inputs[0].data
# Process into matrix according to kfac_approx
# For convolutions, unfold the input, for modules with bias terms, append a 1
a = process_input(a, module, kfac_approx)
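
A rough sketch of the pattern above (helper names are illustrative, not the SINGD API): `functools.partial` binds the module and its inputs at registration time, so the tensor hook only receives the gradient of the output and no separate `self.inputs` buffer is needed; the bias handling below is a crude stand-in for the "append a 1" processing mentioned in the comment.

```python
from functools import partial

import torch
from torch import nn

def accumulate_H_terms(module, inputs, grad_output):
    # `module` and `inputs` are pre-bound via `partial`; PyTorch calls a
    # tensor hook with only the gradient w.r.t. the tensor.
    a = inputs[0].detach()
    if module.bias is not None:
        # Stand-in for the input processing: append a 1 per row so the bias
        # is absorbed into the same factor as the weight.
        a = torch.cat([a, a.new_ones(a.shape[0], 1)], dim=1)
    g = grad_output.detach()
    print(a.shape, g.shape)  # H_K / H_C contributions would be built from a, g

def forward_hook(module, inputs, output):
    if torch.is_grad_enabled():  # mirror the guard: skip no-grad passes
        output.register_hook(partial(accumulate_H_terms, module, inputs))

layer = nn.Linear(4, 3)
layer.register_forward_hook(forward_hook)
layer(torch.randn(2, 4)).sum().backward()  # prints torch.Size([2, 5]) torch.Size([2, 3])
```
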
@@ -660,16 +645,12 @@ def _install_hooks(
if isinstance(mod, self.SUPPORTED_MODULES)
and any(p.data_ptr() in param_ids for p in mod.parameters())
}
handles = []
for module in module_names:
handles.extend(
(
module.register_forward_pre_hook(self._save_input),
module.register_forward_hook(
self._register_tensor_hook_on_output_to_accumulate_H_terms
),
)
handles = [
module.register_forward_hook(
self._register_tensor_hook_on_output_to_accumulate_H_terms
)
for module in module_names
]
return module_names, handles

def _compute_natural_gradient(self, module: Module) -> Tuple[Tensor, ...]:
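
The simplified installation logic above boils down to one forward hook per supported layer; a minimal, self-contained illustration (the module filter and names are placeholders, not SINGD's internals):

```python
from torch import nn

SUPPORTED = (nn.Linear, nn.Conv2d)  # stand-in for the supported module types

def forward_hook(module, inputs, output):
    pass  # in SINGD, this registers the tensor hook on `output`

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
modules = [m for m in model.modules() if isinstance(m, SUPPORTED)]

# One handle per module, instead of a forward-pre hook plus a forward hook.
handles = [m.register_forward_hook(forward_hook) for m in modules]

# ... run forward/backward passes and optimizer steps ...

for handle in handles:
    handle.remove()  # detach the optimizer's hooks from the model
```
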
@@ -865,7 +846,6 @@ def set_current_grad_scale(self, grad_scale: float):
"m_Cs",
"H_Ks",
"H_Cs",
"inputs",
]

def state_dict(self) -> Dict[str, Any]:
7 changes: 0 additions & 7 deletions test/utils.py
@@ -115,16 +115,9 @@ def compare_optimizers( # noqa: C901
rtol_hook = rtol_hook if rtol_hook is not None else rtol

if check_hook_quantities:
assert set(optim1.inputs.keys()) == set(optim2.inputs.keys())
assert set(optim1.H_Ks.keys()) == set(optim2.H_Ks.keys())
assert set(optim1.H_Cs.keys()) == set(optim2.H_Cs.keys())

for name in optim1.inputs:
inputs1, inputs2 = optim1.inputs[name], optim2.inputs[name]
report_nonclose(
inputs1, inputs2, atol=atol_hook, rtol=rtol_hook, name="inputs"
)

for name in optim1.H_Ks:
H_K1 = optim1.H_Ks[name].value.to_dense()
H_K2 = optim2.H_Ks[name].value.to_dense()