[ADD] Use TN formulation of Dangel, 2023 to compute average patches #61

Merged: 3 commits, Nov 3, 2023
2 changes: 2 additions & 0 deletions setup.cfg
@@ -55,6 +55,8 @@ test =
     torchvision
     matplotlib # visual structure tests
     imageio # visual structure tests
+    memory_profiler # for measuring memory consumption on CPU
+    codetiming # for measuring run time

 # Dependencies needed to run the tests (semicolon/line-separated)
 lint =
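The two new test dependencies are exercised by the performance test added in this PR. For quick reference, here is a minimal sketch of how their APIs are used below; the toy `work` function is made up for illustration:

```python
from codetiming import Timer
from memory_profiler import memory_usage


def work() -> list:
    """Toy workload whose cost we want to measure."""
    return [i * i for i in range(1_000_000)]


# peak memory (in MiB) while running `work`, sampled every 0.1 ms
peak = memory_usage(work, interval=1e-4, max_usage=True)
print(f"Peak memory: {peak:.1f} MiB.")

# wall-clock run time in seconds, printed on exit of the context manager
with Timer(text="work() took {:.2e} s"):
    work()
```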
73 changes: 66 additions & 7 deletions singd/optim/utils.py
@@ -4,8 +4,9 @@
 from typing import Tuple, Union

 import torch.nn.functional as F
+from einconv import index_pattern
 from einconv.utils import get_conv_paddings
-from einops import rearrange, reduce
+from einops import einsum, rearrange, reduce
 from torch import Tensor, cat
 from torch.nn import Conv2d, Linear, Module
 from torch.nn.modules.utils import _pair
@@ -59,6 +60,65 @@ def _extract_patches(
     return rearrange(x_unfold, "b c_in_k1_k2 o1_o2 -> b o1_o2 c_in_k1_k2")


+def _extract_averaged_patches(
+    x: Tensor,
+    kernel_size: Union[Tuple[int, int], int],
+    stride: Union[Tuple[int, int], int],
+    padding: Union[Tuple[int, int], int, str],
+    dilation: Union[Tuple[int, int], int],
+    groups: int,
+) -> Tensor:
+    """Extract averaged patches from the input of a 2d-convolution.
+
+    The patches are averaged over channel groups and output locations.
+
+    Uses the tensor network formulation of convolution from
+    [Dangel, 2023](https://arxiv.org/abs/2307.02275).
+
+    Args:
+        x: Input to a 2d-convolution. Has shape `[batch_size, C_in, I1, I2]`.
+        kernel_size: The convolution's kernel size supplied as 2-tuple or integer.
+        stride: The convolution's stride supplied as 2-tuple or integer.
+        padding: The convolution's padding supplied as 2-tuple, integer, or string.
+        dilation: The convolution's dilation supplied as 2-tuple or integer.
+        groups: The number of channel groups.
+
+    Returns:
+        A tensor of shape `[batch_size, C_in // groups * K1 * K2]` where each row
+        `[b, :]` contains the flattened patch of sample `b` averaged over all
+        output locations and channel groups.
+    """
+    # average over channel groups
+    x = rearrange(x, "b (g c_in) i1 i2 -> b g c_in i1 i2", g=groups)
+    x = reduce(x, "b g c_in i1 i2 -> b c_in i1 i2", "mean")
+
+    # TODO For convolutions with special structure, we don't even need to compute
+    # the index pattern tensors, or can resort to contracting only slices thereof.
+    # For this to work, `einconv`'s TN simplification mechanism must first be
+    # refactored to work purely symbolically. Once that is done, the contraction
+    # below can be carried out even more efficiently (in memory and run time) for
+    # structured convolutions.
+
+    # compute index pattern tensors, averaged over the output dimension
+    patterns = []
+    input_sizes = x.shape[-2:]
+    for i, k, s, p, d in zip(
+        input_sizes,
+        _pair(kernel_size),
+        _pair(stride),
+        (padding, padding) if isinstance(padding, str) else _pair(padding),
+        _pair(dilation),
+    ):
+        pi = index_pattern(
+            i, k, stride=s, padding=p, dilation=d, dtype=x.dtype, device=x.device
+        )
+        pi = reduce(pi, "k o i -> k i", "mean")
+        patterns.append(pi)
+
+    x = einsum(x, *patterns, "b c_in i1 i2, k1 i1, k2 i2 -> b c_in k1 k2")
+    return rearrange(x, "b c_in k1 k2 -> b (c_in k1 k2)")
+
+
 def process_input(x: Tensor, module: Module, kfac_approx: str) -> Tensor:
     """Unfold the input for convolutions, append ones if biases are present.

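To see why contracting with averaged index patterns equals averaging unfolded patches, consider a 1d convolution. The sketch below hand-builds the (assumed) semantics of an index pattern tensor `Pi[k, o, i]`, which is 1 exactly when input location `i` meets kernel entry `k` to produce output location `o`; `einconv.index_pattern` computes such tensors for general stride, padding, and dilation. Averaging `Pi` over `o` before the contraction yields the averaged patch directly, without ever materializing one patch per output location:

```python
from einops import einsum, reduce
from torch import allclose, rand, zeros

i_size, k_size, stride = 6, 3, 1  # input size, kernel size (no padding/dilation)
o_size = (i_size - k_size) // stride + 1  # number of output locations

# index pattern: Pi[k, o, i] = 1 iff input location i is multiplied with
# kernel entry k to produce output location o
Pi = zeros(k_size, o_size, i_size)
for k in range(k_size):
    for o in range(o_size):
        Pi[k, o, o * stride + k] = 1.0

x = rand(i_size)

# unfold-then-average: materializes all o_size patches, then reduces over o
patches = einsum(Pi, x, "k o i, i -> o k")
averaged = reduce(patches, "o k -> k", "mean")

# average-then-contract: never materializes the per-output patches
direct = einsum(reduce(Pi, "k o i -> k i", "mean"), x, "k i, i -> k")

assert allclose(averaged, direct)
```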
@@ -95,20 +155,19 @@ def conv2d_process_input(x: Tensor, layer: Conv2d, kfac_approx: str) -> Tensor:

     Returns:
         The processed input. Has shape
-        `[batch_size, O1 * O2, C_in // groups * K1 * K2 (+ 1)]` for `"reduce"` and
+        `[batch_size, C_in // groups * K1 * K2 (+ 1)]` for `"reduce"` and
         `[batch_size * O1 * O2, C_in // groups * K1 * K2 (+ 1)]` for `"expand"`.
         The `+1` is active if the layer has a bias.
     """
-    x = _extract_patches(
+    patch_extractor_fn = (
+        _extract_patches if kfac_approx == "expand" else _extract_averaged_patches
+    )
+    x = patch_extractor_fn(
         x, layer.kernel_size, layer.stride, layer.padding, layer.dilation, layer.groups
     )

     if kfac_approx == "expand":
         # KFAC-expand approximation
         x = rearrange(x, "b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2")
-    else:
-        # KFAC-reduce approximation
-        x = reduce(x, "b o1_o2 c_in_k1_k2 -> b c_in_k1_k2", "mean")

     if layer.bias is not None:
         x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)
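As a quick sanity check of the shapes documented above, a minimal sketch; the layer and input sizes are made up for illustration, and with a `3x3` kernel and `padding=1` the `28x28` output grid gives `O1 * O2 = 784`:

```python
from torch import rand
from torch.nn import Conv2d

from singd.optim.utils import conv2d_process_input

layer = Conv2d(3, 8, kernel_size=3, padding=1, bias=True)  # C_in * K1 * K2 = 27
x = rand(32, 3, 28, 28)

# "expand": one row per sample and output location, "+ 1" column for the bias
print(conv2d_process_input(x, layer, "expand").shape)  # [32 * 784, 28]
# "reduce": patches averaged over output locations first
print(conv2d_process_input(x, layer, "reduce").shape)  # [32, 28]
```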
137 changes: 137 additions & 0 deletions test/optim/test_utils.py
@@ -0,0 +1,137 @@
"""Test utility functions of the optimizer."""

from test.utils import report_nonclose
from typing import Any, Dict

from codetiming import Timer
from einops import reduce
from memory_profiler import memory_usage
from pytest import mark
from torch import Tensor, manual_seed, rand

from singd.optim.utils import _extract_averaged_patches, _extract_patches

CASES = [
    {
        "batch_size": 20,
        "in_channels": 10,
        "input_size": (28, 28),
        "kernel_size": (3, 3),
        "stride": (1, 1),
        "padding": (1, 1),
        "dilation": (1, 1),
        "groups": 2,  # must divide in_channels
        "seed": 0,
    }
]
CASE_IDS = [
    "_".join([f"{k}={v}".replace(" ", "") for k, v in case.items()]) for case in CASES
]


@mark.parametrize("case", CASES, ids=CASE_IDS)
def test_extract_average_patches(case: Dict[str, Any]):
    """Compare averaged patches with the averaged output of patches.

    Args:
        case: Dictionary of test case parameters.
    """
    manual_seed(case["seed"])
    x = rand(case["batch_size"], case["in_channels"], *case["input_size"])

    kernel_size = case["kernel_size"]
    stride = case["stride"]
    padding = case["padding"]
    dilation = case["dilation"]
    groups = case["groups"]

    patches = _extract_patches(x, kernel_size, stride, padding, dilation, groups)
    truth = reduce(patches, "b o1_o2 c_in_k1_k2 -> b c_in_k1_k2", "mean")

    averaged_patches = _extract_averaged_patches(
        x, kernel_size, stride, padding, dilation, groups
    )

    report_nonclose(averaged_patches, truth, rtol=1e-5, atol=1e-7)


PERFORMANCE_CASES = [
    {
        "batch_size": 128,
        "in_channels": 10,
        "input_size": (256, 256),
        "kernel_size": (5, 5),
        "stride": (2, 2),
        "padding": (1, 1),
        "dilation": (1, 1),
        "groups": 1,  # must divide in_channels
        "seed": 0,
    }
]
PERFORMANCE_CASE_IDS = [
    "_".join([f"{k}={v}".replace(" ", "") for k, v in case.items()])
    for case in PERFORMANCE_CASES
]


@mark.parametrize("case", PERFORMANCE_CASES, ids=PERFORMANCE_CASE_IDS)
def test_performance_extract_average_patches(case: Dict[str, Any]):
    """Compare performance of averaged patches vs averaged output of patches.

    Compares run time and memory consumption.

    Args:
        case: Dictionary of test case parameters.
    """
    x_shape = (case["batch_size"], case["in_channels"], *case["input_size"])
    seed = case["seed"]

    kernel_size = case["kernel_size"]
    stride = case["stride"]
    padding = case["padding"]
    dilation = case["dilation"]
    groups = case["groups"]

    def inefficient_fn() -> Tensor:
        """Compute average patches inefficiently.

        Returns:
            Average patches.
        """
        manual_seed(seed)
        x = rand(*x_shape)
        patches = _extract_patches(x, kernel_size, stride, padding, dilation, groups)
        return reduce(patches, "b o1_o2 c_in_k1_k2 -> b c_in_k1_k2", "mean")

    def efficient_fn() -> Tensor:
        """Compute average patches efficiently.

        Returns:
            Average patches.
        """
        manual_seed(seed)
        x = rand(*x_shape)
        return _extract_averaged_patches(
            x, kernel_size, stride, padding, dilation, groups
        )

    # measure memory
    mem_inefficient = memory_usage(inefficient_fn, interval=1e-4, max_usage=True)
    mem_efficient = memory_usage(efficient_fn, interval=1e-4, max_usage=True)
    print(f"Memory used by inefficient function: {mem_inefficient:.1f} MiB.")
    print(f"Memory used by efficient function: {mem_efficient:.1f} MiB.")

    # measure run time
    with Timer(text="Inefficient function took {:.2e} s") as timer:
        inefficient_result = inefficient_fn()
    t_inefficient = timer.last
    with Timer(text="Efficient function took {:.2e} s") as timer:
        efficient_result = efficient_fn()
    t_efficient = timer.last

    # compare all performance specs
    report_nonclose(inefficient_result, efficient_result, rtol=1e-5, atol=1e-7)
    # NOTE This may break for cases with a small unfolded input, or if the
    # built-in version of `unfold` becomes more efficient.
    assert mem_efficient < mem_inefficient
    assert t_efficient < t_inefficient
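Why the memory assertion should hold for this case: the unfold-based path materializes a `[batch_size, O1 * O2, C_in * K1 * K2]` patch tensor, while the TN path only contracts the `[batch_size, C_in, I1, I2]` input with two small index patterns. A back-of-the-envelope check (float32) for the performance case above:

```python
batch_size, c_in, i1 = 128, 10, 256
k, s, p, d = 5, 2, 1, 1
o1 = (i1 + 2 * p - d * (k - 1) - 1) // s + 1  # 127 output locations per dimension

unfolded = batch_size * o1 * o1 * c_in * k * k * 4  # bytes of the patch tensor
inputs = batch_size * c_in * i1 * i1 * 4  # bytes of the input itself
print(f"unfold: {unfolded / 2**30:.2f} GiB vs. input: {inputs / 2**20:.0f} MiB")
# unfold: 1.92 GiB vs. input: 320 MiB -> the unfolded patches dominate peak memory
```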