From fa502d2d5780985fa4aa3bb472dc6daaa4f2af4f Mon Sep 17 00:00:00 2001
From: Runa Eschenhagen <33333409+runame@users.noreply.github.com>
Date: Tue, 7 May 2024 15:51:01 +0100
Subject: [PATCH] [FIX] KFAC scale for `loss_average="batch+sequence"` (#110)

* Fix EF scale for >2d outputs
* Fix scale of MC Fisher for >2d outputs
* Increase tolerance of CG test
* Increase CG precision
* Increase tolerance of CG test (again)
* Increase tolerance of CG test (again)
* Increase tolerance of CG test (again)
* Exclude numerically unstable cases from inverse tests
* Increase tolerance of Jacobian tests
* Adjust loss scale for KFAC-EF
* Add test for KFAC-MC with weight sharing
* Increase tolerance of KFAC (log)det tests
* Add test case for #107 (expand setting scaling issue)
* Fix KFAC scale for batch+sequence loss average
* Fix isort
* Fix test_multi_dim_output
* Change how _num_per_example_loss_terms is inferred and allow setting it explicitly
* Fix darglint and flake8
* Improve docstring for num_per_example_loss_terms
* Minor review fixes
---
 curvlinops/kfac.py |  78 ++++++++++++++++++++++++++++-----
 test/test_kfac.py  | 106 ++++++++++++++++++++++++++++++++++++++++++---
 test/utils.py      |  49 ++++++++++++++++++++-
 3 files changed, 215 insertions(+), 18 deletions(-)

diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py
index d394291..4909d8e 100644
--- a/curvlinops/kfac.py
+++ b/curvlinops/kfac.py
@@ -108,7 +108,7 @@ class KFACLinearOperator(_LinearOperator):
     )
     _SUPPORTED_KFAC_APPROX: Tuple[str, ...] = ("expand", "reduce")

-    def __init__(
+    def __init__(  # noqa: C901
         self,
         model_func: Module,
         loss_func: MSELoss,
@@ -122,6 +122,7 @@ def __init__(
         mc_samples: int = 1,
         kfac_approx: str = "expand",
         loss_average: Union[None, str] = "batch",
+        num_per_example_loss_terms: Optional[int] = None,
         separate_weight_and_bias: bool = True,
         num_data: Optional[int] = None,
         batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
@@ -188,6 +189,16 @@ def __init__(
                 language modeling. If ``None``, the loss function is a sum. This
                 argument is used to ensure that the preconditioner is scaled
                 consistently with the loss and the gradient. Default: ``"batch"``.
+            num_per_example_loss_terms: Number of per-example loss terms, e.g., the
+                number of tokens in a sequence. The model outputs will have
+                ``num_data * num_per_example_loss_terms * C`` entries, where ``C`` is
+                the dimension of the random variable we define the likelihood over --
+                for the ``CrossEntropyLoss`` it will be the number of classes, for the
+                ``MSELoss`` and ``BCEWithLogitsLoss`` it will be the size of the last
+                dimension of the model outputs/targets (our convention here).
+                If ``None``, ``num_per_example_loss_terms`` is inferred from the data at
+                the cost of one traversal through the data loader. It is expected to be
+                the same for all examples. Defaults to ``None``.
             separate_weight_and_bias: Whether to treat weights and biases separately.
                 Defaults to ``True``.
             num_data: Number of data points. If ``None``, it is inferred from the data
@@ -197,6 +208,7 @@ def __init__(
                 entry of the iterates from ``data`` and return their batch size.

         Raises:
+            RuntimeError: If the check for deterministic behavior fails.
             ValueError: If the loss function is not supported.
             ValueError: If the loss average is not supported.
             ValueError: If the loss average is ``None`` and the loss function's
@@ -261,12 +273,57 @@ def __init__(
             params,
             data,
             progressbar=progressbar,
-            check_deterministic=check_deterministic,
+            check_deterministic=False,
             shape=shape,
             num_data=num_data,
             batch_size_fn=batch_size_fn,
         )

+        self._set_num_per_example_loss_terms(num_per_example_loss_terms)
+
+        if check_deterministic:
+            old_device = self._device
+            self.to_device(device("cpu"))
+            try:
+                self._check_deterministic()
+            except RuntimeError as e:
+                raise e
+            finally:
+                self.to_device(old_device)
+
+    def _set_num_per_example_loss_terms(
+        self, num_per_example_loss_terms: Optional[int]
+    ):
+        """Set the number of per-example loss terms.
+
+        Args:
+            num_per_example_loss_terms: Number of per-example loss terms. If ``None``,
+                it is inferred from the data at the cost of one traversal through the
+                data loader.
+
+        Raises:
+            ValueError: If the number of loss terms is not divisible by the number of
+                data points.
+        """
+        if num_per_example_loss_terms is None:
+            # Determine the number of per-example loss terms
+            num_loss_terms = sum(
+                (
+                    y.numel()
+                    if isinstance(self._loss_func, CrossEntropyLoss)
+                    else y.shape[:-1].numel()
+                )
+                for (_, y) in self._loop_over_data(desc="_num_per_example_loss_terms")
+            )
+            if num_loss_terms % self._N_data != 0:
+                raise ValueError(
+                    "The number of loss terms must be divisible by the number of data "
+                    f"points; num_loss_terms={num_loss_terms}, N_data={self._N_data}."
+                )
+            self._num_per_example_loss_terms = num_loss_terms // self._N_data
+        else:
+            self._num_per_example_loss_terms = num_per_example_loss_terms
+
     def _reset_matrix_properties(self):
         """Reset matrix properties."""
         self._trace = None
@@ -723,12 +780,6 @@ def _accumulate_gradient_covariance(
         batch_size = g.shape[0]
         if isinstance(module, Conv2d):
             g = rearrange(g, "batch c o1 o2 -> batch o1 o2 c")
-        sequence_length = g.shape[1:-1].numel()
-        num_loss_terms = {
-            None: batch_size,
-            "batch": batch_size,
-            "batch+sequence": batch_size * sequence_length,
-        }[self._loss_average]

         if self._kfac_approx == "expand":
             # KFAC-expand approximation
@@ -737,13 +788,20 @@
             # KFAC-reduce approximation
             g = reduce(g, "batch ... d_out -> batch d_out", "sum")

+        # Compute correction for the loss scaling depending on the loss reduction used
+        num_loss_terms = {
+            None: batch_size,
+            "batch": batch_size,
+            "batch+sequence": batch_size * self._num_per_example_loss_terms,
+        }[self._loss_average]
         # self._mc_samples will be 1 if fisher_type != "mc"
         correction = {
             None: 1.0 / self._mc_samples,
             "batch": num_loss_terms**2 / (self._N_data * self._mc_samples),
             "batch+sequence": num_loss_terms**2
-            / (self._N_data * self._mc_samples * sequence_length),
+            / (self._N_data * self._mc_samples * self._num_per_example_loss_terms),
         }[self._loss_average]
+
         covariance = einsum(g, g, "b i,b j->i j").mul_(correction)

         if module_name not in self._gradient_covariances:
@@ -786,7 +844,7 @@ def _hook_accumulate_input_covariance(

         if self._kfac_approx == "expand":
             # KFAC-expand approximation
-            scale = x.shape[1:-1].numel()  # sequence_length
+            scale = x.shape[1:-1].numel()  # sequence length
             x = rearrange(x, "batch ... d_in -> (batch ...) d_in")
         else:
             # KFAC-reduce approximation
diff --git a/test/test_kfac.py b/test/test_kfac.py
index 16c812a..06394dd 100644
--- a/test/test_kfac.py
+++ b/test/test_kfac.py
@@ -3,7 +3,9 @@
 from test.cases import DEVICES, DEVICES_IDS
 from test.utils import (
     Conv2dModel,
+    UnetModel,
     WeightShareModel,
+    binary_classification_targets,
     classification_targets,
     ggn_block_diagonal,
     regression_targets,
@@ -415,12 +417,14 @@ def test_kfac_inplace_activations(dev: device):


 @mark.parametrize("fisher_type", KFACLinearOperator._SUPPORTED_FISHER_TYPE)
-@mark.parametrize("loss", [MSELoss, CrossEntropyLoss], ids=["mse", "ce"])
+@mark.parametrize(
+    "loss", [MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], ids=["mse", "ce", "bce"]
+)
 @mark.parametrize("reduction", ["mean", "sum"])
 @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
 def test_multi_dim_output(
     fisher_type: str,
-    loss: Union[MSELoss, CrossEntropyLoss],
+    loss: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
     reduction: str,
     dev: device,
 ):
@@ -436,17 +440,26 @@ def test_multi_dim_output(
     # set up loss function, data, and model
     loss_func = loss(reduction=reduction).to(dev)
     loss_average = None if reduction == "sum" else "batch+sequence"
+    X1 = rand(2, 7, 5, 5)
+    X2 = rand(4, 7, 5, 5)
     if isinstance(loss_func, MSELoss):
         data = [
-            (rand(2, 7, 5, 5), regression_targets((2, 7, 5, 3))),
-            (rand(4, 7, 5, 5), regression_targets((4, 7, 5, 3))),
+            (X1, regression_targets((2, 7, 5, 3))),
+            (X2, regression_targets((4, 7, 5, 3))),
+        ]
+        manual_seed(711)
+        model = Sequential(Linear(5, 4), Linear(4, 3)).to(dev)
+    elif issubclass(loss, BCEWithLogitsLoss):
+        data = [
+            (X1, binary_classification_targets((2, 7, 5, 3))),
+            (X2, binary_classification_targets((4, 7, 5, 3))),
         ]
         manual_seed(711)
         model = Sequential(Linear(5, 4), Linear(4, 3)).to(dev)
     else:
         data = [
-            (rand(2, 7, 5, 5), classification_targets((2, 7, 5), 3)),
-            (rand(4, 7, 5, 5), classification_targets((4, 7, 5), 3)),
+            (X1, classification_targets((2, 7, 5), 3)),
+            (X2, classification_targets((4, 7, 5), 3)),
         ]
         manual_seed(711)
         # rearrange is necessary to get the expected output shape for ce loss
@@ -479,7 +492,7 @@ def test_multi_dim_output(
     data_flat = [
         (
             (x, y.flatten(start_dim=0, end_dim=-2))
-            if isinstance(loss_func, MSELoss)
+            if isinstance(loss_func, (MSELoss, BCEWithLogitsLoss))
             else (x, y.flatten(start_dim=0))
         )
         for x, y in data
@@ -497,6 +510,85 @@ def test_multi_dim_output(
     report_nonclose(kfac_mat, kfac_flat_mat)


+@mark.parametrize("fisher_type", KFACLinearOperator._SUPPORTED_FISHER_TYPE)
+@mark.parametrize(
+    "loss", [MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], ids=["mse", "ce", "bce"]
+)
+@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
+def test_expand_setting_scaling(
+    fisher_type: str,
+    loss: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
+    dev: device,
+):
+    """Test KFAC for correct scaling for expand setting with mean reduction loss.
+
+    See #107 for details.
+
+    Args:
+        fisher_type: The type of Fisher matrix to use.
+        loss: The loss function to use.
+        dev: The device to run the test on.
+ """ + manual_seed(0) + + # set up data, loss function, and model + X1 = rand(2, 3, 32, 32) + X2 = rand(4, 3, 32, 32) + if issubclass(loss, MSELoss): + data = [ + (X1, regression_targets((2, 32, 32, 3))), + (X2, regression_targets((4, 32, 32, 3))), + ] + elif issubclass(loss, BCEWithLogitsLoss): + data = [ + (X1, binary_classification_targets((2, 32, 32, 3))), + (X2, binary_classification_targets((4, 32, 32, 3))), + ] + else: + data = [ + (X1, classification_targets((2, 32, 32), 3)), + (X2, classification_targets((4, 32, 32), 3)), + ] + model = UnetModel(loss).to(dev) + params = list(model.parameters()) + + # KFAC with sum reduction + loss_func = loss(reduction="sum").to(dev) + kfac_sum = KFACLinearOperator( + model, + loss_func, + params, + data, + fisher_type=fisher_type, + loss_average=None, + ) + # FOOF does not scale the gradient covariances, even when using a mean reduction + if fisher_type != "forward-only": + # Simulate a mean reduction by manually scaling the gradient covariances + loss_term_factor = 32 * 32 # number of spatial locations of model output + if issubclass(loss, (MSELoss, BCEWithLogitsLoss)): + output_random_variable_size = 3 + # MSE loss averages over number of output channels + loss_term_factor *= output_random_variable_size + for ggT in kfac_sum._gradient_covariances.values(): + ggT /= kfac_sum._N_data * loss_term_factor + kfac_simulated_mean_mat = kfac_sum @ eye(kfac_sum.shape[1]) + + # KFAC with mean reduction + loss_func = loss(reduction="mean").to(dev) + kfac_mean = KFACLinearOperator( + model, + loss_func, + params, + data, + fisher_type=fisher_type, + loss_average="batch+sequence", + ) + kfac_mean_mat = kfac_mean @ eye(kfac_mean.shape[1]) + + report_nonclose(kfac_simulated_mean_mat, kfac_mean_mat) + + def test_bug_device_change_invalidates_parameter_mapping(): """Reproduce #77: Loading KFAC from GPU to CPU invalidates the internal mapping. diff --git a/test/utils.py b/test/utils.py index c462df7..07e6b82 100644 --- a/test/utils.py +++ b/test/utils.py @@ -8,7 +8,19 @@ from einops.layers.torch import Rearrange from numpy import eye, ndarray from torch import Tensor, cat, cuda, device, dtype, from_numpy, rand, randint -from torch.nn import AdaptiveAvgPool2d, Conv2d, Flatten, Module, Parameter, Sequential +from torch.nn import ( + AdaptiveAvgPool2d, + BCEWithLogitsLoss, + Conv2d, + CrossEntropyLoss, + Flatten, + Identity, + Module, + MSELoss, + Parameter, + Sequential, + Upsample, +) from curvlinops import GGNLinearOperator @@ -287,6 +299,41 @@ def forward(self, x: Tensor) -> Tensor: return self._model(x) +class UnetModel(Module): + """Simple Unet-like model where the number of spatial locations varies.""" + + def __init__(self, loss: Module): + """Initialize the model.""" + if loss not in {MSELoss, CrossEntropyLoss, BCEWithLogitsLoss}: + raise ValueError( + "Loss has to be one of MSELoss, CrossEntropyLoss, BCEWithLogitsLoss. " + f"Got {loss}." + ) + super().__init__() + self._model = Sequential( + Conv2d(3, 2, 3, padding=1, stride=2), + Conv2d(2, 2, 3, padding=3 // 2), + Upsample(scale_factor=2, mode="nearest"), + Conv2d(2, 3, 3, padding=1), + ( + Rearrange("batch c h w -> batch h w c") + if issubclass(loss, (MSELoss, BCEWithLogitsLoss)) + else Identity() + ), + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass of the model. + + Args: + x: Input to the forward pass. + + Returns: + Output of the model. + """ + return self._model(x) + + def cast_input( X: Union[Tensor, MutableMapping], target_dtype: dtype ) -> Union[Tensor, MutableMapping]: