[FIX] KFAC scale for loss_average="batch+sequence" (#110)
* Fix EF scale for >2d outputs

* Fix scale of MC Fisher for >2d outputs

* Increase tolerance of CG test

* Increase CG precision

* Increase tolerance of CG test (again)

* Increase tolerance of CG test (again)

* Increase tolerance of CG test (again)

* Exclude numerically unstable cases from inverse tests

* Increase tolerance of Jacobian tests

* Adjust loss scale for KFAC-EF

* Add test for KFAC-MC with weight sharing

* Increase tolerance of KFAC (log)det tests

* Add test case for #107 (expand setting scaling issue)

* Fix KFAC scale for batch+sequence loss average

* Fix isort

* Fix test_multi_dim_output

* Change how _num_per_example_loss_terms is inferred and allow setting it explicitly

* Fix darglint and flake8

* Improve docstring for num_per_example_loss_terms

* Minor review fixes
runame committed May 7, 2024
1 parent d43e884 commit fa502d2
Showing 3 changed files with 215 additions and 18 deletions.
78 changes: 68 additions & 10 deletions curvlinops/kfac.py
@@ -108,7 +108,7 @@ class KFACLinearOperator(_LinearOperator):
)
_SUPPORTED_KFAC_APPROX: Tuple[str, ...] = ("expand", "reduce")

def __init__(
def __init__( # noqa: C901
self,
model_func: Module,
loss_func: MSELoss,
@@ -122,6 +122,7 @@ def __init__(
mc_samples: int = 1,
kfac_approx: str = "expand",
loss_average: Union[None, str] = "batch",
num_per_example_loss_terms: Optional[int] = None,
separate_weight_and_bias: bool = True,
num_data: Optional[int] = None,
batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
@@ -188,6 +189,16 @@ def __init__(
language modeling. If ``None``, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient. Default: ``"batch"``.
num_per_example_loss_terms: Number of per-example loss terms, e.g., the
number of tokens in a sequence. The model outputs will have
``num_data * num_per_example_loss_terms * C`` entries, where ``C`` is
the dimension of the random variable we define the likelihood over --
for the ``CrossEntropyLoss`` it will be the number of classes, for the
``MSELoss`` and ``BCEWithLogitsLoss`` it will be the size of the last
dimension of the model outputs/targets (our convention here).
If ``None``, ``num_per_example_loss_terms`` is inferred from the data at
the cost of one traversal through the data loader. It is expected to be
the same for all examples. Defaults to ``None``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
num_data: Number of data points. If ``None``, it is inferred from the data
@@ -197,6 +208,7 @@ def __init__(
entry of the iterates from ``data`` and return their batch size.
Raises:
RuntimeError: If the check for deterministic behavior fails.
ValueError: If the loss function is not supported.
ValueError: If the loss average is not supported.
ValueError: If the loss average is ``None`` and the loss function's
@@ -261,12 +273,57 @@ def __init__(
params,
data,
progressbar=progressbar,
check_deterministic=check_deterministic,
check_deterministic=False,
shape=shape,
num_data=num_data,
batch_size_fn=batch_size_fn,
)

self._set_num_per_example_loss_terms(num_per_example_loss_terms)

if check_deterministic:
old_device = self._device
self.to_device(device("cpu"))
try:
self._check_deterministic()
except RuntimeError as e:
raise e
finally:
self.to_device(old_device)

def _set_num_per_example_loss_terms(
self, num_per_example_loss_terms: Optional[int]
):
"""Set the number of per-example loss terms.
Args:
num_per_example_loss_terms: Number of per-example loss terms. If ``None``,
it is inferred from the data at the cost of one traversal through the
data loader.
Raises:
ValueError: If the number of loss terms is not divisible by the number of
data points.
"""
if num_per_example_loss_terms is None:
# Determine the number of per-example loss terms
num_loss_terms = sum(
(
y.numel()
if isinstance(self._loss_func, CrossEntropyLoss)
else y.shape[:-1].numel()
)
for (_, y) in self._loop_over_data(desc="_num_per_example_loss_terms")
)
if num_loss_terms % self._N_data != 0:
raise ValueError(
"The number of loss terms must be divisible by the number of data "
f"points; num_loss_terms={num_loss_terms}, N_data={self._N_data}."
)
self._num_per_example_loss_terms = num_loss_terms // self._N_data
else:
self._num_per_example_loss_terms = num_per_example_loss_terms

def _reset_matrix_properties(self):
"""Reset matrix properties."""
self._trace = None
@@ -723,12 +780,6 @@ def _accumulate_gradient_covariance(
batch_size = g.shape[0]
if isinstance(module, Conv2d):
g = rearrange(g, "batch c o1 o2 -> batch o1 o2 c")
sequence_length = g.shape[1:-1].numel()
num_loss_terms = {
None: batch_size,
"batch": batch_size,
"batch+sequence": batch_size * sequence_length,
}[self._loss_average]

if self._kfac_approx == "expand":
# KFAC-expand approximation
@@ -737,13 +788,20 @@
# KFAC-reduce approximation
g = reduce(g, "batch ... d_out -> batch d_out", "sum")

# Compute correction for the loss scaling depending on the loss reduction used
num_loss_terms = {
None: batch_size,
"batch": batch_size,
"batch+sequence": batch_size * self._num_per_example_loss_terms,
}[self._loss_average]
# self._mc_samples will be 1 if fisher_type != "mc"
correction = {
None: 1.0 / self._mc_samples,
"batch": num_loss_terms**2 / (self._N_data * self._mc_samples),
"batch+sequence": num_loss_terms**2
/ (self._N_data * self._mc_samples * sequence_length),
/ (self._N_data * self._mc_samples * self._num_per_example_loss_terms),
}[self._loss_average]

covariance = einsum(g, g, "b i,b j->i j").mul_(correction)

if module_name not in self._gradient_covariances:
@@ -786,7 +844,7 @@ def _hook_accumulate_input_covariance(

if self._kfac_approx == "expand":
# KFAC-expand approximation
scale = x.shape[1:-1].numel() # sequence_length
scale = x.shape[1:-1].numel() # sequence length
x = rearrange(x, "batch ... d_in -> (batch ...) d_in")
else:
# KFAC-reduce approximation
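Note (not part of the commit): the sketch below replays the two pieces of logic added above, namely inferring num_per_example_loss_terms from the targets and building the loss-scale correction applied to the gradient covariance. The toy sizes (6 data points, 7 loss terms per example, 3 classes, batch size 4) are assumptions chosen only to make the arithmetic concrete.

# Illustrative sketch of the scaling logic added in this commit; toy values only.
from torch import randint

N_data, S, C = 6, 7, 3  # examples, per-example loss terms, classes (assumed)
mc_samples = 1  # self._mc_samples is 1 unless fisher_type == "mc"

# Mirror _set_num_per_example_loss_terms for CrossEntropyLoss targets:
# every target entry is one loss term.
data = [(None, randint(0, C, (2, S))), (None, randint(0, C, (4, S)))]
num_loss_terms = sum(y.numel() for _, y in data)  # 2*7 + 4*7 = 42
assert num_loss_terms % N_data == 0
num_per_example_loss_terms = num_loss_terms // N_data  # 42 // 6 = 7

# Mirror the correction built in _accumulate_gradient_covariance for one batch.
batch_size = 4
num_loss_terms_batch = {
    None: batch_size,
    "batch": batch_size,
    "batch+sequence": batch_size * num_per_example_loss_terms,
}
correction = {
    None: 1.0 / mc_samples,
    "batch": num_loss_terms_batch["batch"] ** 2 / (N_data * mc_samples),
    "batch+sequence": num_loss_terms_batch["batch+sequence"] ** 2
    / (N_data * mc_samples * num_per_example_loss_terms),
}
print(correction["batch+sequence"])  # 28**2 / (6 * 1 * 7) = 784 / 42

The key change visible in the diff is that the "batch+sequence" correction now uses the dataset-level _num_per_example_loss_terms instead of the per-module sequence length read off the gradient shape, which is what went wrong for models whose weight-sharing dimension differs across layers (see #107).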
106 changes: 99 additions & 7 deletions test/test_kfac.py
@@ -3,7 +3,9 @@
from test.cases import DEVICES, DEVICES_IDS
from test.utils import (
Conv2dModel,
UnetModel,
WeightShareModel,
binary_classification_targets,
classification_targets,
ggn_block_diagonal,
regression_targets,
@@ -415,12 +417,14 @@ def test_kfac_inplace_activations(dev: device):


@mark.parametrize("fisher_type", KFACLinearOperator._SUPPORTED_FISHER_TYPE)
@mark.parametrize("loss", [MSELoss, CrossEntropyLoss], ids=["mse", "ce"])
@mark.parametrize(
"loss", [MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], ids=["mse", "ce", "bce"]
)
@mark.parametrize("reduction", ["mean", "sum"])
@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
def test_multi_dim_output(
fisher_type: str,
loss: Union[MSELoss, CrossEntropyLoss],
loss: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
reduction: str,
dev: device,
):
@@ -436,17 +440,26 @@
# set up loss function, data, and model
loss_func = loss(reduction=reduction).to(dev)
loss_average = None if reduction == "sum" else "batch+sequence"
X1 = rand(2, 7, 5, 5)
X2 = rand(4, 7, 5, 5)
if isinstance(loss_func, MSELoss):
data = [
(rand(2, 7, 5, 5), regression_targets((2, 7, 5, 3))),
(rand(4, 7, 5, 5), regression_targets((4, 7, 5, 3))),
(X1, regression_targets((2, 7, 5, 3))),
(X2, regression_targets((4, 7, 5, 3))),
]
manual_seed(711)
model = Sequential(Linear(5, 4), Linear(4, 3)).to(dev)
elif issubclass(loss, BCEWithLogitsLoss):
data = [
(X1, binary_classification_targets((2, 7, 5, 3))),
(X2, binary_classification_targets((4, 7, 5, 3))),
]
manual_seed(711)
model = Sequential(Linear(5, 4), Linear(4, 3)).to(dev)
else:
data = [
(rand(2, 7, 5, 5), classification_targets((2, 7, 5), 3)),
(rand(4, 7, 5, 5), classification_targets((4, 7, 5), 3)),
(X1, classification_targets((2, 7, 5), 3)),
(X2, classification_targets((4, 7, 5), 3)),
]
manual_seed(711)
# rearrange is necessary to get the expected output shape for ce loss
@@ -479,7 +492,7 @@
data_flat = [
(
(x, y.flatten(start_dim=0, end_dim=-2))
if isinstance(loss_func, MSELoss)
if isinstance(loss_func, (MSELoss, BCEWithLogitsLoss))
else (x, y.flatten(start_dim=0))
)
for x, y in data
@@ -497,6 +510,85 @@
report_nonclose(kfac_mat, kfac_flat_mat)


@mark.parametrize("fisher_type", KFACLinearOperator._SUPPORTED_FISHER_TYPE)
@mark.parametrize(
"loss", [MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], ids=["mse", "ce", "bce"]
)
@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
def test_expand_setting_scaling(
fisher_type: str,
loss: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
dev: device,
):
"""Test KFAC for correct scaling for expand setting with mean reduction loss.
See #107 for details.
Args:
fisher_type: The type of Fisher matrix to use.
loss: The loss function to use.
dev: The device to run the test on.
"""
manual_seed(0)

# set up data, loss function, and model
X1 = rand(2, 3, 32, 32)
X2 = rand(4, 3, 32, 32)
if issubclass(loss, MSELoss):
data = [
(X1, regression_targets((2, 32, 32, 3))),
(X2, regression_targets((4, 32, 32, 3))),
]
elif issubclass(loss, BCEWithLogitsLoss):
data = [
(X1, binary_classification_targets((2, 32, 32, 3))),
(X2, binary_classification_targets((4, 32, 32, 3))),
]
else:
data = [
(X1, classification_targets((2, 32, 32), 3)),
(X2, classification_targets((4, 32, 32), 3)),
]
model = UnetModel(loss).to(dev)
params = list(model.parameters())

# KFAC with sum reduction
loss_func = loss(reduction="sum").to(dev)
kfac_sum = KFACLinearOperator(
model,
loss_func,
params,
data,
fisher_type=fisher_type,
loss_average=None,
)
# FOOF does not scale the gradient covariances, even when using a mean reduction
if fisher_type != "forward-only":
# Simulate a mean reduction by manually scaling the gradient covariances
loss_term_factor = 32 * 32 # number of spatial locations of model output
if issubclass(loss, (MSELoss, BCEWithLogitsLoss)):
output_random_variable_size = 3
# MSE loss averages over number of output channels
loss_term_factor *= output_random_variable_size
for ggT in kfac_sum._gradient_covariances.values():
ggT /= kfac_sum._N_data * loss_term_factor
kfac_simulated_mean_mat = kfac_sum @ eye(kfac_sum.shape[1])

# KFAC with mean reduction
loss_func = loss(reduction="mean").to(dev)
kfac_mean = KFACLinearOperator(
model,
loss_func,
params,
data,
fisher_type=fisher_type,
loss_average="batch+sequence",
)
kfac_mean_mat = kfac_mean @ eye(kfac_mean.shape[1])

report_nonclose(kfac_simulated_mean_mat, kfac_mean_mat)


def test_bug_device_change_invalidates_parameter_mapping():
"""Reproduce #77: Loading KFAC from GPU to CPU invalidates the internal mapping.
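Usage note (illustrative, not from the test suite): because num_per_example_loss_terms can now be passed explicitly, callers who know the number of loss terms per example up front can skip the extra traversal of the data loader that inference would require. The model, shapes, and variable names in the sketch below are assumptions for illustration only.

# Hypothetical usage of the new num_per_example_loss_terms argument.
from torch import manual_seed, rand
from torch.nn import Linear, MSELoss

from curvlinops.kfac import KFACLinearOperator

manual_seed(0)
model = Linear(5, 3)
loss_func = MSELoss(reduction="mean")
# Two batches; every example has 7 "sequence" positions, i.e. 7 per-example loss terms.
data = [
    (rand(2, 7, 5), rand(2, 7, 3)),
    (rand(4, 7, 5), rand(4, 7, 3)),
]
kfac = KFACLinearOperator(
    model,
    loss_func,
    list(model.parameters()),
    data,
    loss_average="batch+sequence",
    # Known up front, so KFAC skips the extra pass over `data` needed to infer it.
    num_per_example_loss_terms=7,
)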
49 changes: 48 additions & 1 deletion test/utils.py
@@ -8,7 +8,19 @@
from einops.layers.torch import Rearrange
from numpy import eye, ndarray
from torch import Tensor, cat, cuda, device, dtype, from_numpy, rand, randint
from torch.nn import AdaptiveAvgPool2d, Conv2d, Flatten, Module, Parameter, Sequential
from torch.nn import (
AdaptiveAvgPool2d,
BCEWithLogitsLoss,
Conv2d,
CrossEntropyLoss,
Flatten,
Identity,
Module,
MSELoss,
Parameter,
Sequential,
Upsample,
)

from curvlinops import GGNLinearOperator

@@ -287,6 +299,41 @@ def forward(self, x: Tensor) -> Tensor:
return self._model(x)


class UnetModel(Module):
"""Simple Unet-like model where the number of spatial locations varies."""

def __init__(self, loss: Module):
"""Initialize the model."""
if loss not in {MSELoss, CrossEntropyLoss, BCEWithLogitsLoss}:
raise ValueError(
"Loss has to be one of MSELoss, CrossEntropyLoss, BCEWithLogitsLoss. "
f"Got {loss}."
)
super().__init__()
self._model = Sequential(
Conv2d(3, 2, 3, padding=1, stride=2),
Conv2d(2, 2, 3, padding=3 // 2),
Upsample(scale_factor=2, mode="nearest"),
Conv2d(2, 3, 3, padding=1),
(
Rearrange("batch c h w -> batch h w c")
if issubclass(loss, (MSELoss, BCEWithLogitsLoss))
else Identity()
),
)

def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the model.
Args:
x: Input to the forward pass.
Returns:
Output of the model.
"""
return self._model(x)


def cast_input(
X: Union[Tensor, MutableMapping], target_dtype: dtype
) -> Union[Tensor, MutableMapping]:
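For orientation, a small sketch of the new test helper (an assumption here is that the repository root is on the Python path so that test.utils is importable): UnetModel halves the spatial resolution with the strided convolution and restores it with the upsampling layer, so the number of spatial locations varies across layers, which is the situation the expand-setting scaling test exercises.

# Illustrative shape check for the new UnetModel test helper.
from torch import rand
from torch.nn import MSELoss

from test.utils import UnetModel

model = UnetModel(MSELoss)
x = rand(2, 3, 32, 32)  # batch of 2, 3 channels, 32x32 spatial
y = model(x)
# Conv2d(stride=2): 32 -> 16, Upsample(scale_factor=2): 16 -> 32,
# and the final Rearrange puts channels last for MSELoss-style targets.
assert y.shape == (2, 32, 32, 3)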
