Refactor batch_averaged argument
runame committed Oct 31, 2023
1 parent 43bd6f2 commit dc5dd92
Showing 14 changed files with 111 additions and 47 deletions.
2 changes: 1 addition & 1 deletion docs/examples/example_03_param_groups.py
@@ -60,7 +60,7 @@
"momentum": 0.9,
"weight_decay": 1e-2,
"lr_cov": 1e-2,
"batch_averaged": True,
"batch_averaged": "batch",
"T": 1,
"alpha1": 0.5,
}
2 changes: 1 addition & 1 deletion docs/examples/example_04_advanced.py
@@ -90,7 +90,7 @@
"momentum": 0.9,
"weight_decay": 1e-2,
"lr_cov": 1e-2,
"batch_averaged": True,
"batch_averaged": "batch",
"T": 1,
"alpha1": 0.5,
"structures": ("dense", "dense"),
8 changes: 4 additions & 4 deletions singd/optim/accumulator.py
@@ -11,15 +11,15 @@ class BatchAccumulator:
The quantity's class must support multiplication with a scalar and addition.
"""

def __init__(self, batch_averaged: bool = True):
def __init__(self, averaged: bool = True):
"""Initialize the accumulator.
Args:
batch_averaged: Whether the quantity is averaged over the batch.
averaged: Whether the quantity is averaged over the batch.
If ``False``, assumes sum. Default: ``True``.
"""
self.value = None
self.batch_averaged = batch_averaged
self.averaged = averaged
self.batch_size_total = 0

def update(self, other: Any, batch_size: int):
@@ -33,7 +33,7 @@ def update(self, other: Any, batch_size: int):
self.value = other
self.batch_size_total = batch_size
else:
if self.batch_averaged:
if self.averaged:
scale = self.batch_size_total / (self.batch_size_total + batch_size)
self.value = self.value * scale + other * (1 - scale)
else:
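A minimal usage sketch of the renamed accumulator (assuming `BatchAccumulator` is importable from `singd.optim.accumulator`; the numbers are illustrative): with `averaged=True`, per-micro-batch means are folded into a running mean weighted by their batch sizes.

import torch

from singd.optim.accumulator import BatchAccumulator  # assumed import path

# Means of some per-sample quantity over two micro-batches (illustrative values).
mean_mb1 = torch.tensor(2.0)  # micro-batch of 8 samples
mean_mb2 = torch.tensor(4.0)  # micro-batch of 24 samples

acc = BatchAccumulator(averaged=True)
acc.update(mean_mb1, batch_size=8)
acc.update(mean_mb2, batch_size=24)

# Weighted running mean: 2.0 * 8/32 + 4.0 * 24/32 = 3.5
print(acc.value)  # tensor(3.5000)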
23 changes: 19 additions & 4 deletions singd/optim/optimizer.py
@@ -82,7 +82,7 @@ def __init__(
alpha1: float = 0.5, # α₁ in the paper
weight_decay: float = 0.0, # γ in the paper
T: int = 10, # T in the paper
batch_averaged: bool = True,
batch_averaged: Union[None, str] = "batch",
lr_cov: Union[float, Callable[[int], float]] = 1e-2, # β₁ in the paper
structures: Tuple[str, str] = ("dense", "dense"),
kfac_approx: str = "expand",
@@ -123,7 +123,14 @@ def __init__(
Default: `0.0`.
T: Pre-conditioner update frequency. Default: `10`.
batch_averaged: Whether the loss function is a mean over per-sample
losses. Default is `True`. If `False`, the loss function is a sum.
losses and if yes, over which dimensions the mean is taken.
If `"batch"`, the loss function is a mean over as many terms as
the size of the mini-batch. If `"batch+sequence"`, the loss
function is a mean over as many terms as the size of the
mini-batch times the sequence length, e.g. in the case of
language modeling. If `None`, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient. Default: `"batch"`.
lr_cov: (β₁ in the paper) Learning rate for the updates of the pre-
conditioner momenta \\(\\mathbf{m}_\\mathbf{K}\\) and
\\(\\mathbf{m}_\\mathbf{C}\\). Default is `1e-2`. Also allows for a
@@ -299,6 +306,13 @@ def _check_param_groups(self, model: Module) -> Dict[int, int]:
"kfac_approx has to be set to either 'expand' or 'reduce', "
f"but was set to {group['kfac_approx']}."
)
if group["batch_averaged"] is not None:
if group["batch_averaged"] not in ["batch", "batch+sequence"]:
raise ValueError(
"batch_averaged has to be set to either None, 'batch', "
"or 'batch+sequence', but was set to "
f"{group['batch_averaged']}."
)

# Find out which parameter is in which group
param_to_group_idx = {}
@@ -585,10 +599,11 @@ def _accumulate_H_terms(
H_C.all_reduce(op=op)

# maybe set up fresh accumulators (they get flushed in `.step`)
averaged = batch_averaged is not None
if module_name not in self.H_Ks:
self.H_Ks[module_name] = BatchAccumulator(batch_averaged=batch_averaged)
self.H_Ks[module_name] = BatchAccumulator(averaged=averaged)
if module_name not in self.H_Cs:
self.H_Cs[module_name] = BatchAccumulator(batch_averaged=batch_averaged)
self.H_Cs[module_name] = BatchAccumulator(averaged=averaged)

self.H_Ks[module_name].update(H_K, batch_size)
self.H_Cs[module_name].update(H_C, batch_size)
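A sketch of the newly accepted values (assuming `SINGD` is importable from `singd.optim.optimizer`; the hyperparameters are illustrative and only `batch_averaged` varies):

from torch import nn

from singd.optim.optimizer import SINGD  # assumed import path


def make_optimizer(batch_averaged):
    """Build SINGD on a small MLP with illustrative hyperparameters."""
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
    return SINGD(
        model,
        lr=1e-3,
        momentum=0.9,
        damping=1e-4,
        alpha1=0.5,
        weight_decay=1e-2,
        lr_cov=1e-2,
        T=1,
        batch_averaged=batch_averaged,
        structures=("dense", "dense"),
    )


# Loss is a mean over the mini-batch, e.g. CrossEntropyLoss(reduction="mean")
# on inputs without a sequence dimension.
optim_mean = make_optimizer("batch")

# Loss is a mean over batch *and* sequence positions, e.g. a language-modeling
# cross-entropy averaged over all batch_size * sequence_length predictions.
optim_lm = make_optimizer("batch+sequence")

# Loss is a sum over per-sample losses (reduction="sum").
optim_sum = make_optimizer(None)

# Any other value, including the old booleans, now fails the check added in
# `_check_param_groups` and raises a ValueError.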
65 changes: 52 additions & 13 deletions singd/optim/utils.py
@@ -141,14 +141,25 @@ def linear_process_input(x: Tensor, layer: Linear, kfac_approx: str) -> Tensor:


def process_grad_output(
grad_output: Tensor, module: Module, batch_averaged: bool, kfac_approx: str
grad_output: Tensor,
module: Module,
batch_averaged: Union[None, str],
kfac_approx: str,
) -> Tensor:
"""Reshape output gradients into matrices and apply scaling.
Args:
grad_output: The gradient w.r.t. the output of the module.
module: The module.
batch_averaged: Whether the loss is a mean over per-sample losses.
batch_averaged: Whether the loss function is a mean over per-sample
losses and if yes, over which dimensions the mean is taken.
If `"batch"`, the loss function is a mean over as many terms as
the size of the mini-batch. If `"batch+sequence"`, the loss
function is a mean over as many terms as the size of the
mini-batch times the sequence length, e.g. in the case of
language modeling. If `None`, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient.
kfac_approx: The KFAC approximation to use for linear weight-sharing
layers. Possible values are `"expand"` and `"reduce"`.
@@ -159,6 +170,7 @@ def process_grad_output(
AssertionError: If `kfac_approx` is neither `"expand"` nor `"reduce"`.
NotImplementedError: If the module is not supported.
"""
assert batch_averaged in {None, "batch", "batch+sequence"}
assert kfac_approx in {"expand", "reduce"}
grad_scaling = 1.0
if isinstance(module, Conv2d):
@@ -174,14 +186,22 @@


def conv2d_process_grad_output(
g: Tensor, batch_averaged: bool, scaling: float, kfac_approx: str
g: Tensor, batch_averaged: Union[None, str], scaling: float, kfac_approx: str
) -> Tensor:
"""Process the output gradient of a convolution before the self-inner product.
Args:
g: Gradient w.r.t. the output of a convolution. Has shape
`[batch_size, C_out, O1, O2]`.
batch_averaged: Whether to multiply with the batch size.
batch_averaged: Whether the loss function is a mean over per-sample
losses and if yes, over which dimensions the mean is taken.
If `"batch"`, the loss function is a mean over as many terms as
the size of the mini-batch. If `"batch+sequence"`, the loss
function is a mean over as many terms as the size of the
mini-batch times the sequence length, e.g. in the case of
language modeling. If `None`, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient.
scaling: An additional scaling that will be applied to the gradient.
kfac_approx: The KFAC approximation to use. Possible values are
`"expand"` and `"reduce"`.
@@ -190,11 +210,14 @@ def conv2d_process_grad_output(
The processed scaled gradient. Has shape `[batch_size, C_out]` for
`"reduce"` and `[batch_size * O1 * O2, C_out]` for `"expand"`.
"""
# The scaling by `sqrt(batch_size)` when `batch_averaged=True` assumes
# that we are in the reduce setting, i.e. the number of loss terms equals
# the batch size.
batch_size = g.shape[0]
scaling = scaling * sqrt(batch_size) if batch_averaged else scaling
spatial_size = g.shape[2] * g.shape[3]
# We have to adjust the scaling to account for the mean reduction of the
# loss used for computing the gradients when batch_averaged is not None.
num_loss_terms = (
batch_size * spatial_size if batch_averaged == "batch+sequence" else batch_size
)
scaling = scaling * sqrt(num_loss_terms) if batch_averaged else scaling

if kfac_approx == "expand":
# KFAC-expand approximation
Expand All @@ -207,15 +230,23 @@ def conv2d_process_grad_output(


def linear_process_grad_output(
g: Tensor, batch_averaged: bool, scaling: float, kfac_approx: str
g: Tensor, batch_averaged: Union[None, str], scaling: float, kfac_approx: str
) -> Tensor:
"""Process the output gradient of a linear layer before the self-inner product.
Args:
g: Gradient w.r.t. the output of a linear layer. Has shape
`[batch_size, ..., d_out]` where `...` is an arbitrary number of
weight-shared dimensions.
batch_averaged: Whether to multiply with the batch size.
batch_averaged: Whether the loss function is a mean over per-sample
losses and if yes, over which dimensions the mean is taken.
If `"batch"`, the loss function is a mean over as many terms as
the size of the mini-batch. If `"batch+sequence"`, the loss
function is a mean over as many terms as the size of the
mini-batch times the sequence length, e.g. in the case of
language modeling. If `None`, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient.
scaling: An additional scaling that will be applied to the gradient.
kfac_approx: The KFAC approximation to use for linear weight-sharing
layers. Possible values are `"expand"` and `"reduce"`.
@@ -224,14 +255,22 @@
The processed gradient. Has shape `[batch_size, d_out]` for `"reduce"`
and `[batch_size * ..., d_out]` for `"expand"`.
"""
batch_size = g.shape[0]
weight_shared_dims_size = g[0, ..., 0].numel()
# We have to adjust the scaling to account for the mean reduction of the
# loss used for computing the gradients when batch_averaged is not None.
num_loss_terms = (
batch_size * weight_shared_dims_size
if batch_averaged == "batch+sequence"
else batch_size
)
scaling = scaling * sqrt(num_loss_terms) if batch_averaged else scaling

if kfac_approx == "expand":
# KFAC-expand approximation
g = rearrange(g, "b ... d_out -> (b ...) d_out")
else:
# KFAC-reduce approximation
g = reduce(g, "b ... d_out -> b d_out", "sum")

# The use of `g.shape[0]` assumes that the setting of the loss, i.e. the
# number of loss terms, matches the `kfac_approx` that is used.
scaling = scaling * sqrt(g.shape[0]) if batch_averaged else scaling
return g * scaling
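The number of loss terms implied by `batch_averaged` fixes the square-root correction applied to the output gradients before the self-inner product. A standalone sketch of that decision for the linear case (re-derived from `linear_process_grad_output` above, not the library function itself):

from math import sqrt

import torch


def grad_output_scale(g: torch.Tensor, batch_averaged) -> float:
    # `g` has shape [batch_size, ..., d_out]; `...` are weight-shared
    # (e.g. sequence) dimensions, mirroring `linear_process_grad_output`.
    batch_size = g.shape[0]
    weight_shared_dims_size = g[0, ..., 0].numel()
    if batch_averaged is None:  # loss is a sum: no correction needed
        return 1.0
    if batch_averaged == "batch":  # mean over batch_size loss terms
        return sqrt(batch_size)
    # "batch+sequence": mean over batch_size * sequence_length loss terms
    return sqrt(batch_size * weight_shared_dims_size)


g = torch.randn(4, 7, 3)  # [batch_size=4, sequence=7, d_out=3]
print(grad_output_scale(g, None))              # 1.0
print(grad_output_scale(g, "batch"))           # sqrt(4) = 2.0
print(grad_output_scale(g, "batch+sequence"))  # sqrt(4 * 7) ≈ 5.29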
2 changes: 1 addition & 1 deletion test/optim/test_autocast.py
@@ -48,7 +48,7 @@ def test_autocast():
"momentum": 0.9,
"weight_decay": 1e-2,
"lr_cov": 1e-2,
"batch_averaged": True,
"batch_averaged": "batch",
"T": 1,
"alpha1": 0.5,
"structures": ("dense", "dense"),
2 changes: 1 addition & 1 deletion test/optim/test_checkpointing.py
@@ -42,7 +42,7 @@ def setup() -> Tuple[Sequential, Module, SINGD]:
"momentum": 0.9,
"weight_decay": 1e-2,
"lr_cov": 1e-2,
"batch_averaged": True,
"batch_averaged": "batch",
"T": 1,
"alpha1": 0.5,
"structures": ("dense", "dense"),
32 changes: 22 additions & 10 deletions test/optim/test_kfac.py
@@ -2,7 +2,7 @@

from test.optim.utils import Transpose, jacobians_naive
from test.utils import DEVICE_IDS, DEVICES
from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Union

from einops import rearrange, reduce
from pytest import mark
@@ -88,15 +88,13 @@ def conv2d_model(setting: str, bias: bool) -> Sequential:

@mark.parametrize("model", MODELS.items(), ids=MODELS.keys())
@mark.parametrize("setting", ["expand", "reduce"])
@mark.parametrize(
"batch_averaged", [True, False], ids=["batch_averaged", "not_averaged"]
)
@mark.parametrize("averaged", [True, False], ids=["batch_averaged", "not_averaged"])
@mark.parametrize("bias", [True, False], ids=["bias", "no_bias"])
@mark.parametrize("device", DEVICES, ids=DEVICE_IDS)
def test_kfac(
model: Tuple[str, Callable],
setting: str,
batch_averaged: bool,
averaged: bool,
bias: bool,
device: device,
):
Expand All @@ -109,7 +107,7 @@ def test_kfac(
model: Tuple of model name and function takes `bias` as input and
returns the model.
setting: KFAC approximation setting. Either `"expand"` or `"reduce"`.
batch_averaged: Whether to average over the batch dimension.
averaged: Whether the loss uses a mean reduction.
bias: Whether to use a bias term.
device: Device to run the test on.
@@ -122,8 +120,14 @@
# Setup model and inputs x.
model_name, model_fn = model

if model_name == "conv2d" and setting == "expand" and batch_averaged:
return # TODO This case will work when issue #31 is fixed.
# Set appropriate batch_averaged argument based on averaged and setting.
if averaged:
if setting == "expand":
batch_averaged = "batch+sequence"
else:
batch_averaged = "batch"
else:
batch_averaged = None

if model_name == "conv2d":
model: Module = model_fn(setting, bias)
@@ -211,14 +215,22 @@ def forward(self, x: Tensor, setting: str) -> Tensor:
class KFACMSE:
"""Class for computing the KFAC approximation with the MSE loss."""

def __init__(self, model: Module, batch_averaged: bool, setting: str):
def __init__(self, model: Module, batch_averaged: Union[None, str], setting: str):
"""Initialize the KFAC approximation class.
Installs forward and backward hooks to the model.
Args:
model: The model.
batch_averaged: Whether the loss is a mean over per-sample losses.
batch_averaged: Whether the loss function is a mean over per-sample
losses and if yes, over which dimensions the mean is taken.
If `"batch"`, the loss function is a mean over as many terms as
the size of the mini-batch. If `"batch+sequence"`, the loss
function is a mean over as many terms as the size of the
mini-batch times the sequence length, e.g. in the case of
language modeling. If `None`, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient.
setting: KFAC approximation setting. Possible values are `'expand'`
and `'reduce'`.
"""
5 changes: 2 additions & 3 deletions test/optim/test_lin2023simplifying.py
@@ -59,7 +59,6 @@ def test_compare_lin2023simplifying(): # noqa: C901
momentum = 0.9
weight_decay = 1e-2
lr_cov = 1e-2
batch_averaged = True
T = 1
alpha1_beta2 = 0.5

@@ -74,7 +73,7 @@
TInv=T,
faster=True,
lr_cov=lr_cov,
batch_averaged=batch_averaged,
batch_averaged=True,
)

def lr_cov_schedule(step: int) -> float:
@@ -101,7 +100,7 @@ def lr_cov_schedule(step: int) -> float:
damping=damping,
alpha1=alpha1_beta2,
weight_decay=weight_decay,
batch_averaged=batch_averaged,
batch_averaged="batch",
T=T,
lr_cov=lr_cov_schedule,
structures=("dense", "dense"),
5 changes: 2 additions & 3 deletions test/optim/test_lin2023simplifying_ddp.py
@@ -80,7 +80,6 @@ def test_compare_lin2023simplifying_ddp(): # noqa: C901
momentum = 0.9
weight_decay = 1e-2
lr_cov = 1e-2
batch_averaged = True
T = 1
alpha1_beta2 = 0.5

@@ -95,7 +94,7 @@ def test_compare_lin2023simplifying_ddp(): # noqa: C901
TInv=T,
faster=True,
lr_cov=lr_cov,
batch_averaged=batch_averaged,
batch_averaged=True,
)
optim_ours = SINGD(
model_ours,
@@ -105,7 +104,7 @@
damping=damping,
alpha1=alpha1_beta2,
weight_decay=weight_decay,
batch_averaged=batch_averaged,
batch_averaged="batch",
T=T,
lr_cov=lr_cov,
structures=("dense", "dense"),
2 changes: 1 addition & 1 deletion test/optim/test_micro_batches.py
@@ -49,7 +49,7 @@ def test_micro_batches():
"momentum": 0.9,
"weight_decay": 1e-2,
"lr_cov": 1e-2,
"batch_averaged": True,
"batch_averaged": "batch",
"T": 1,
"alpha1": 0.5,
"structures": ("dense", "dense"),