Skip to content

Commit

Permalink
[REF] Average channel groups before unfolding (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel authored Nov 2, 2023
1 parent e045f6e commit 125ff06
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions singd/optim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ def _extract_patches(
padding_as_int.append(p_left)
padding = tuple(padding_as_int)

# average channel groups
x = rearrange(x, "b (g c_in) i1 i2 -> b g c_in i1 i2", g=groups)
x = reduce(x, "b g c_in i1 i2 -> b c_in i1 i2", "mean")

x_unfold = F.unfold(
x, kernel_size, dilation=dilation, padding=padding, stride=stride
)
# separate the channel groups
x_unfold = rearrange(
x_unfold, "b (g c_in_k1_k2) o1_o2 -> b g c_in_k1_k2 o1_o2", g=groups
)
return reduce(x_unfold, "b g c_in_k1_k2 o1_o2 -> b o1_o2 c_in_k1_k2", "mean")
return rearrange(x_unfold, "b c_in_k1_k2 o1_o2 -> b o1_o2 c_in_k1_k2")


def process_input(x: Tensor, module: Module, kfac_approx: str) -> Tensor:
Expand Down

0 comments on commit 125ff06

Please sign in to comment.