From 81b4709abb5ba23b347108e0b052c9850afec587 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Thu, 2 Nov 2023 18:17:36 -0400 Subject: [PATCH 1/3] [ADD] Use TN formulation of Dangel, 2023 to compute average patches --- setup.cfg | 2 + singd/optim/utils.py | 72 ++++++++++++++++++-- test/optim/test_utils.py | 137 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+), 7 deletions(-) create mode 100644 test/optim/test_utils.py diff --git a/setup.cfg b/setup.cfg index 1174fa1..3ddaccc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,6 +55,8 @@ test = torchvision matplotlib # visual structure tests imageio # visual structure tests + memory_profiler # for measuring memory consumption on CPU + codetiming # for measuring run time # Dependencies needed to run the tests (semicolon/line-separated) lint = diff --git a/singd/optim/utils.py b/singd/optim/utils.py index b00e1ec..4ea8caf 100644 --- a/singd/optim/utils.py +++ b/singd/optim/utils.py @@ -4,8 +4,9 @@ from typing import Tuple, Union import torch.nn.functional as F +from einconv import index_pattern from einconv.utils import get_conv_paddings -from einops import rearrange, reduce +from einops import einsum, rearrange, reduce from torch import Tensor, cat from torch.nn import Conv2d, Linear, Module from torch.nn.modules.utils import _pair @@ -59,6 +60,64 @@ def _extract_patches( return rearrange(x_unfold, "b c_in_k1_k2 o1_o2 -> b o1_o2 c_in_k1_k2") +def _extract_averaged_patches( + x: Tensor, + kernel_size: Union[Tuple[int, int], int], + stride: Union[Tuple[int, int], int], + padding: Union[Tuple[int, int], int, str], + dilation: Union[Tuple[int, int], int], + groups: int, +) -> Tensor: + """Extract averaged patches from the input of a 2d-convolution. + + The patches are averaged over channel groups and output locations. + + Uses the tensor network formulation of convolution from + [Dangel, 2023](https://arxiv.org/abs/2307.02275). + + Args: + x: Input to a 2d-convolution. 
Has shape `[batch_size, C_in, I1, I2]`. + kernel_size: The convolution's kernel size supplied as 2-tuple or integer. + stride: The convolution's stride supplied as 2-tuple or integer. + padding: The convolution's padding supplied as 2-tuple, integer, or string. + dilation: The convolution's dilation supplied as 2-tuple or integer. + groups: The number of channel groups. + + Returns: + A tensor of shape `[batch_size, C_in // groups * K1 * K2]` where each column + `[b, :]` contains the flattened patch of sample `b` averaged over all output + locations and channel groups. + """ + # average channel groups + x = rearrange(x, "b (g c_in) i1 i2 -> b g c_in i1 i2", g=groups) + x = reduce(x, "b g c_in i1 i2 -> b c_in i1 i2", "mean") + + # TODO For convolutions with special structure, we don't even need to compute + # the index pattern tensors, or can resort to contracting only slices thereof. + # In order for this to work `einconv`'s TN simplification mechanism must first + # be refactored to work purely symbolically. Once this is done, it will be + # possible to do the below even more efficiently (memory and run time) for + # structured convolutions. + + # compute index pattern tensors, average output dimension + patterns = [] + for i, k, s, p, d in zip( + x.shape[-2:], + _pair(kernel_size), + _pair(stride), + (padding, padding) if isinstance(padding, str) else _pair(padding), + _pair(dilation), + ): + pi = index_pattern( + i, k, stride=s, padding=p, dilation=d, dtype=x.dtype, device=x.device + ) + pi = reduce(pi, "k o i -> k i", "mean") + patterns.append(pi) + + x = einsum(x, *patterns, "b c_in i1 i2, k1 i1, k2 i2 -> b c_in k1 k2") + return rearrange(x, "b c_in k1 k2 -> b (c_in k1 k2)") + + def process_input(x: Tensor, module: Module, kfac_approx: str) -> Tensor: """Unfold the input for convolutions, append ones if biases are present. @@ -95,20 +154,19 @@ def conv2d_process_input(x: Tensor, layer: Conv2d, kfac_approx: str) -> Tensor: Returns: The processed input. 
Has shape - `[batch_size, O1 * O2, C_in // groups * K1 * K2 (+ 1)]` for `"reduce"` and + `[batch_size, C_in // groups * K1 * K2 (+ 1)]` for `"reduce"` and `[batch_size * O1 * O2, C_in // groups * K1 * K2 (+ 1)]` for `"expand"`. The `+1` is active if the layer has a bias. """ - x = _extract_patches( + patch_extractor_fn = ( + _extract_patches if kfac_approx == "expand" else _extract_averaged_patches + ) + x = patch_extractor_fn( x, layer.kernel_size, layer.stride, layer.padding, layer.dilation, layer.groups ) if kfac_approx == "expand": - # KFAC-expand approximation x = rearrange(x, "b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2") - else: - # KFAC-reduce approximation - x = reduce(x, "b o1_o2 c_in_k1_k2 -> b c_in_k1_k2", "mean") if layer.bias is not None: x = cat([x, x.new_ones(x.shape[0], 1)], dim=1) diff --git a/test/optim/test_utils.py b/test/optim/test_utils.py new file mode 100644 index 0000000..5fd3bd1 --- /dev/null +++ b/test/optim/test_utils.py @@ -0,0 +1,137 @@ +"""Test utility functions of the optimizer.""" + +from test.utils import report_nonclose +from typing import Any, Dict + +from codetiming import Timer +from einops import reduce +from memory_profiler import memory_usage +from pytest import mark +from torch import Tensor, manual_seed, rand + +from singd.optim.utils import _extract_averaged_patches, _extract_patches + +CASES = [ + { + "batch_size": 20, + "in_channels": 10, + "input_size": (28, 28), + "kernel_size": (3, 3), + "stride": (1, 1), + "padding": (1, 1), + "dilation": (1, 1), + "groups": 2, # must divide in_channels + "seed": 0, + } +] +CASE_IDS = [ + "_".join([f"{k}={v}".replace(" ", "") for k, v in case.items()]) for case in CASES +] + + +@mark.parametrize("case", CASES, ids=CASE_IDS) +def test_extract_average_patches(case: Dict[str, Any]): + """Compare averaged patches with the averaged output of patches. + + Args: + case: Dictionary of test case parameters. 
+ """ + manual_seed(case["seed"]) + x = rand(case["batch_size"], case["in_channels"], *case["input_size"]) + + kernel_size = case["kernel_size"] + stride = case["stride"] + padding = case["padding"] + dilation = case["dilation"] + groups = case["groups"] + + patches = _extract_patches(x, kernel_size, stride, padding, dilation, groups) + truth = reduce(patches, "b o1_o2 c_in_k1_k2 -> b c_in_k1_k2", "mean") + + averaged_patches = _extract_averaged_patches( + x, kernel_size, stride, padding, dilation, groups + ) + + report_nonclose(averaged_patches, truth, rtol=1e-5, atol=1e-7) + + +MEMORY_CONSUMPTION_CASES = [ + { + "batch_size": 128, + "in_channels": 10, + "input_size": (256, 256), + "kernel_size": (5, 5), + "stride": (2, 2), + "padding": (1, 1), + "dilation": (1, 1), + "groups": 1, # must divide in_channels + "seed": 0, + } +] +MEMORY_CONSUMPTION_CASE_IDS = [ + "_".join([f"{k}={v}".replace(" ", "") for k, v in case.items()]) + for case in MEMORY_CONSUMPTION_CASES +] + + +@mark.parametrize("case", MEMORY_CONSUMPTION_CASES, ids=MEMORY_CONSUMPTION_CASE_IDS) +def test_performance_extract_average_patches(case: Dict[str, Any]): + """Compare performance of averaged patches vs averaged output of patches. + + Compares run time and memory consumption + + Args: + case: Dictionary of test case parameters. + """ + x_shape = (case["batch_size"], case["in_channels"], *case["input_size"]) + seed = case["seed"] + + kernel_size = case["kernel_size"] + stride = case["stride"] + padding = case["padding"] + dilation = case["dilation"] + groups = case["groups"] + + def inefficient_fn() -> Tensor: + """Compute average patches inefficiently. + + Returns: + Average patches. + """ + manual_seed(seed) + x = rand(*x_shape) + patches = _extract_patches(x, kernel_size, stride, padding, dilation, groups) + return reduce(patches, "b o1_o2 c_in_k1_k2 -> b c_in_k1_k2", "mean") + + def efficient_fn() -> Tensor: + """Compute average patches efficiently. + + Returns: + Average patches. 
+ """ + manual_seed(seed) + x = rand(*x_shape) + return _extract_averaged_patches( + x, kernel_size, stride, padding, dilation, groups + ) + + # measure memory + mem_inefficient = memory_usage(inefficient_fn, interval=1e-4, max_usage=True) + mem_efficient = memory_usage(efficient_fn, interval=1e-4, max_usage=True) + print(f"Memory used by inefficient function: {mem_inefficient:.1f} MiB.") + print(f"Memory used by efficient function: {mem_efficient:.1f} MiB.") + + # measure run time + with Timer(text="Inefficient function took {:.2e} s") as timer: + inefficient_result = inefficient_fn() + t_inefficient = timer.last + with Timer(text="Efficient function took {:.2e} s") as timer: + efficient_result = efficient_fn() + t_efficient = timer.last + + # compare all performance specs + report_nonclose(inefficient_result, efficient_result, rtol=1e-5, atol=1e-7) + # NOTE This may break for cases with small unfolded input, or if the built-in + # version of `unfold` becomes more efficient. + assert mem_efficient < mem_inefficient + assert t_efficient < t_inefficient From e6cff8f02d1a368be10817c1534fe71d87621dce Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Thu, 2 Nov 2023 18:22:55 -0400 Subject: [PATCH 2/3] [REF] Improve variable name --- test/optim/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/optim/test_utils.py b/test/optim/test_utils.py index 5fd3bd1..d085e80 100644 --- a/test/optim/test_utils.py +++ b/test/optim/test_utils.py @@ -55,7 +55,7 @@ def test_extract_average_patches(case: Dict[str, Any]): report_nonclose(averaged_patches, truth, rtol=1e-5, atol=1e-7) -MEMORY_CONSUMPTION_CASES = [ +PERFORMANCE_CASES = [ { "batch_size": 128, "in_channels": 10, @@ -68,13 +68,13 @@ def test_extract_average_patches(case: Dict[str, Any]): "seed": 0, } ] -MEMORY_CONSUMPTION_CASE_IDS = [ +PERFORMANCE_CASE_IDS = [ "_".join([f"{k}={v}".replace(" ", "") for k, v in case.items()]) - for case in 
PERFORMANCE_CASES ] -@mark.parametrize("case", MEMORY_CONSUMPTION_CASES, ids=MEMORY_CONSUMPTION_CASE_IDS) +@mark.parametrize("case", PERFORMANCE_CASES, ids=PERFORMANCE_CASE_IDS) def test_performance_extract_average_patches(case: Dict[str, Any]): """Compare performance of averaged patches vs averaged output of patches. From 75070f1322fd71688a94bae6a43e08df1bf8dde6 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Fri, 3 Nov 2023 09:23:05 -0400 Subject: [PATCH 3/3] [REF] Improve name of input sizes --- singd/optim/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/singd/optim/utils.py b/singd/optim/utils.py index 4ea8caf..79fa649 100644 --- a/singd/optim/utils.py +++ b/singd/optim/utils.py @@ -101,8 +101,9 @@ def _extract_averaged_patches( # compute index pattern tensors, average output dimension patterns = [] + input_sizes = x.shape[-2:] for i, k, s, p, d in zip( - x.shape[-2:], + input_sizes, _pair(kernel_size), _pair(stride), (padding, padding) if isinstance(padding, str) else _pair(padding),