redefining fourier and test automation
Jorgedavyd committed May 31, 2024
1 parent db29944 commit d172c33
Showing 7 changed files with 305 additions and 86 deletions.
8 changes: 4 additions & 4 deletions docs/api/nn.md
@@ -253,8 +253,8 @@ sample_input: Tensor = torch.randn(32, 10) # batch size, input_size
model = MonteCarloFC(
    fc_layer = DeepNeuralNetwork(
        in_features = 10,
-       (20, 20, 1),
-       (nn.ReLU(), nn.ReLU(), nn.Sigmoid())
+       layers = (20, 20, 1),
+       activations = (nn.ReLU(), nn.ReLU(), nn.Sigmoid())
    ),
    dropout = 0.5,
    n_sampling = 50
@@ -274,7 +274,7 @@ sample_input: Tensor = torch.randn(32, 20, 10) # batch size, sequence_length, input_size

norm = RootMeanSquaredNormalization(dim = 10)

-norm = model(sample_input) if False else model(sample_input) #-> output (32, 20, 10)
+norm(sample_input) #-> output (32, 20, 10)

```

@@ -295,7 +295,7 @@ from lightorch.nn.partial import PartialConv2d
from torch import nn, Tensor

sample_input: Tensor = torch.randn(32, 3, 256, 256) # batch size, channels, height, width
-mask_in: Tensor = sample_input()
+mask_in: Tensor = ...

model = nn.Sequential(
    PartialConv2d(in_channels = 3, out_channels = 5, kernel_size = 3, stride = 1, padding = 1),
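For reference, a runnable version of the corrected docs snippet might look as follows. This is a sketch: the Conv2d-style keyword names (`kernel_size`, `stride`, `padding`), the all-ones mask, and the `(input, mask)` forward signature are assumptions suggested by, but not confirmed in, the docs:

```python
import torch
from torch import Tensor
from lightorch.nn.partial import PartialConv2d

sample_input: Tensor = torch.randn(32, 3, 256, 256)  # batch size, channels, height, width
# Hypothetical mask: ones mark observed pixels, zeros mark holes.
mask_in: Tensor = torch.ones_like(sample_input)

layer = PartialConv2d(in_channels=3, out_channels=5, kernel_size=3, stride=1, padding=1)
out = layer(sample_input, mask_in)  # assuming forward accepts (input, mask)
```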
6 changes: 5 additions & 1 deletion lightorch/nn/__init__.py
@@ -1,6 +1,10 @@
from .fourier import *
-from .transformer.attention import *
+from .transformer import *
from .partial import *
from .normalization import *
from .criterions import *
from .complex import *
from .dnn import *
from .kan import *
from .monte_carlo import *
+from .utils import *
88 changes: 40 additions & 48 deletions lightorch/nn/fourier.py
@@ -5,33 +5,39 @@
from torch.nn import init
from math import sqrt
import torch.nn.functional as f
-from typing import Tuple
+from typing import Tuple, Sequence
+from itertools import chain

class _FourierConvNd(nn.Module):
def __init__(
self,
n: int,
in_channels: int,
out_channels: int,
-        *kernel_size,
-        padding: Tuple[int],
+        kernel_size: Tuple[int, ...] | int,
+        padding: Tuple[int, ...] | int,
bias: bool = True,
eps: float = 1e-5,
pre_fft: bool = True,
post_ifft: bool = False,
device=None,
dtype=None,
) -> None:
-        super().__init__()
+        self.n = n
+        if isinstance(kernel_size, tuple):
+            assert len(kernel_size) == n, f'Not a valid kernel size for {n}-convolution'
+        else:
+            kernel_size = (kernel_size, ) * n
+
+        super().__init__()
self.factory_kwargs = {"device": device, "dtype": dtype}
-        self.padding = padding
+        self.padding = self.get_padding(padding)
if pre_fft:
-            self.fft = lambda x: fftn(x, dim=(-i for i in range(1, len(kernel_size))))
+            self.fft = lambda x: fftn(x, dim=tuple(-i for i in range(1, n + 1)))
else:
self.fft = False
if post_ifft:
-            self.ifft = lambda x: ifftn(x, dim=(-i for i in range(1, len(kernel_size))))
+            self.ifft = lambda x: ifftn(x, dim=tuple(-i for i in range(1, n + 1)))
else:
self.ifft = False

@@ -51,12 +57,19 @@ def __init__(
self.bias = None

        self._init_parameters()
-        self._fourier_space(len(kernel_size))

-    def _fourier_space(self, dims: int) -> Tensor:
+    def get_padding(self, padding: Tuple[int, ...] | int) -> Sequence[int]:
+        if isinstance(padding, tuple):
+            assert len(padding) == self.n, f'Not a valid padding scheme for {self.n}-convolution'
+            return tuple(chain.from_iterable((p, ) * 2 for p in reversed(padding)))
+        else:
+            return tuple(chain.from_iterable((padding, ) * 2 for _ in range(self.n)))
+
+    def _fourier_space(self) -> None:
+        # probably deprecated
        if self.bias is not None:
-            self.bias = self.fft(self.bias, dim=(-i for i in range(1, dims)))
-        self.weight = self.fft(self.weight, dim=(-i for i in range(1, dims)))
+            self.bias = self.fft(self.bias)
+        self.weight = self.fft(self.weight)
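As a sanity check on the new `get_padding`: `torch.nn.functional.pad` expects padding pairs ordered from the last dimension inward, which is why the scheme is built over `reversed(padding)`. A standalone sketch of the same expansion (`expand_padding` is a hypothetical mirror of the method, not part of lightorch):

```python
from itertools import chain
from typing import Sequence, Tuple, Union

def expand_padding(padding: Union[Tuple[int, ...], int], n: int) -> Sequence[int]:
    # (p_1, ..., p_n), one entry per spatial dim -> (p_n, p_n, ..., p_1, p_1) for F.pad
    if isinstance(padding, tuple):
        return tuple(chain.from_iterable((p,) * 2 for p in reversed(padding)))
    return tuple(chain.from_iterable((padding,) * 2 for _ in range(n)))

print(expand_padding((1, 2), 2))  # (2, 2, 1, 1): width padded by 2, height by 1
print(expand_padding(1, 3))       # (1, 1, 1, 1, 1, 1)
```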

def _init_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=sqrt(5))
@@ -68,34 +81,13 @@ def _init_parameters(self) -> None:


class _FourierDeconvNd(_FourierConvNd):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        *kernel_size,
-        bias: bool = True,
-        eps: float = 0.00001,
-        pre_fft: bool = True,
-        post_ifft: bool = False,
-        device=None,
-        dtype=None,
-    ) -> None:
-        super().__init__(
-            in_channels,
-            out_channels,
-            *kernel_size,
-            bias=bias,
-            eps=eps,
-            pre_fft=pre_fft,
-            post_ifft=post_ifft,
-            device=device,
-            dtype=dtype,
-        )
+    def __init__(self, n: int, in_channels: int, out_channels: int, kernel_size: Tuple[int, ...] | int, padding: Tuple[int, ...] | int, bias: bool = True, eps: float = 0.00001, pre_fft: bool = True, post_ifft: bool = False, device=None, dtype=None) -> None:
+        super().__init__(n, in_channels, out_channels, kernel_size, padding, bias, eps, pre_fft, post_ifft, device, dtype)


class FourierConv1d(_FourierConvNd):
def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__(1, *args, **kwargs)

def forward(self, input: Tensor) -> Tensor:
if self.fft:
@@ -107,7 +99,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, mode="constant", value=0),
self.one,
self.weight,
-            self.bias,
+            self.bias
)
if self.ifft:
return self.ifft(out)
@@ -116,7 +108,7 @@ def forward(self, input: Tensor) -> Tensor:

class FourierConv2d(_FourierConvNd):
def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__(2, *args, **kwargs)

def forward(self, input: Tensor) -> Tensor:
if self.fft:
@@ -128,7 +120,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
-            self.bias,
+            self.bias
)

if self.ifft:
@@ -138,7 +130,7 @@ def forward(self, input: Tensor) -> Tensor:

class FourierConv3d(_FourierConvNd):
def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__(3, *args, **kwargs)

def forward(self, input: Tensor) -> Tensor:
if self.fft:
@@ -150,7 +142,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
-            self.bias,
+            self.bias
)
if self.ifft:
return self.ifft(out)
@@ -159,7 +151,7 @@ def forward(self, input: Tensor) -> Tensor:

class FourierDeconv1d(_FourierDeconvNd):
def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__(1, *args, **kwargs)

def forward(self, input: Tensor) -> Tensor:
if self.fft:
@@ -171,7 +163,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
-            self.bias,
+            self.bias
)
if self.ifft:
return self.ifft(out)
@@ -180,7 +172,7 @@ def forward(self, input: Tensor) -> Tensor:

class FourierDeconv2d(_FourierDeconvNd):
def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__(2, *args, **kwargs)

def forward(self, input: Tensor) -> Tensor:
if self.fft:
@@ -192,7 +184,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
-            self.bias,
+            self.bias
)
if self.ifft:
return self.ifft(out)
@@ -201,7 +193,7 @@ def forward(self, input: Tensor) -> Tensor:

class FourierDeconv3d(_FourierDeconvNd):
def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__(3, *args, **kwargs)

def forward(self, input: Tensor) -> Tensor:
if self.fft:
@@ -213,7 +205,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
-            self.bias,
+            self.bias
)
if self.ifft:
return self.ifft(out)
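With the dimensionality `n` now bound by each subclass, construction no longer repeats it. A hedged usage sketch; it assumes the parts of `_FourierConvNd` hidden in the collapsed context (for example `self.one`) initialize as the visible code implies:

```python
import torch
from lightorch.nn.fourier import FourierConv2d

conv = FourierConv2d(in_channels=3, out_channels=3, kernel_size=(3, 3), padding=(1, 1))
x = torch.randn(8, 3, 64, 64)  # batch size, channels, height, width
y = conv(x)  # pre_fft=True by default, so the input is FFT'd before the product
```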
22 changes: 16 additions & 6 deletions lightorch/nn/functional.py
@@ -4,26 +4,36 @@
import torch.nn.functional as F
from lightning.pytorch import LightningModule
from einops import rearrange
+from torch.fft import fftn


-def _fourierconvNd(x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor:
+def _fourierconvNd(
+    n: int, x: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5
+) -> Tensor:
+    # To fourier space
+    weight = fftn(weight, dim=tuple(-i for i in range(1, n + 1)))

# weight -> 1, 1, out channels, *kernel_size
-    x *= weight.reshape(1, 1, *weight.shape)  # Convolution in the fourier space
+    x *= weight.reshape(1, 1, *weight.shape) + eps  # Convolution in the fourier space

if bias is not None:
-        return x + bias.reshape(1, 1, -1, 1, 1)
+        bias = fftn(bias, dim=tuple(-i for i in range(1, n + 1)))
+        return x + bias.reshape(1, 1, -1, *[1 for _ in range(n)])

return x


def _fourierdeconvNd(
-    x: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5
+    n: int, x: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5
) -> Tensor:
+    # To fourier space
+    weight = fftn(weight, dim=tuple(-i for i in range(1, n + 1)))

# weight -> 1, 1, out channels, *kernel_size
    x /= weight.reshape(1, 1, *weight.shape) + eps  # Deconvolution in the fourier space

if bias is not None:
-        return x + bias.reshape(1, 1, -1, 1, 1)
+        bias = fftn(bias, dim=tuple(-i for i in range(1, n + 1)))
+        return x + bias.reshape(1, 1, -1, *[1 for _ in range(n)])

return x

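Both helpers lean on the convolution theorem: a pointwise product in Fourier space is a circular convolution in signal space (and division correspondingly inverts it). A quick self-contained check of that identity, independent of lightorch:

```python
import torch
from torch.fft import fft, ifft

n = 16
x, k = torch.randn(n), torch.randn(n)

# Circular convolution computed directly from the definition.
direct = torch.stack([sum(x[j] * k[(i - j) % n] for j in range(n)) for i in range(n)])
# The same result via a pointwise product in the Fourier domain.
fourier = ifft(fft(x) * fft(k)).real

torch.testing.assert_close(direct, fourier, rtol=1e-4, atol=1e-4)
```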
96 changes: 96 additions & 0 deletions lightorch/nn/utils.py
@@ -0,0 +1,96 @@
from torch import nn, Tensor
from typing import List, Sequence
from torchvision.models import (
    VGG19_Weights, vgg19,
    VGG16_Weights, vgg16,
    ResNet50_Weights, resnet50,
    ResNet101_Weights, resnet101,
)


VALID_MODELS = {
    'vgg19': {
        'model': vgg19,
        'weights': VGG19_Weights,
        'valid_layers': list(range(37)),
    },
    'vgg16': {
        'model': vgg16,
        'weights': VGG16_Weights,
        'valid_layers': list(range(31)),
    },
'resnet50': {
'model': resnet50,
'weights': ResNet50_Weights,
'valid_layers': [
'conv1',
'bn1',
'relu',
'maxpool',
'layer1',
'layer2',
'layer3',
'layer4',
'avgpool',
'fc',
]
},
'resnet101': {
'model': resnet101,
'weights': ResNet101_Weights,
'valid_layers': [
'conv1',
'bn1',
'relu',
'maxpool',
'layer1',
'layer2',
'layer3',
'layer4',
'avgpool',
'fc',
]
},
}


class FeatureExtractor(nn.Module):
    def __init__(self, layers: Sequence[int] = [4, 9, 18], model_str: str = 'vgg19') -> None:
        assert model_str in VALID_MODELS, f'Model must be one of {list(VALID_MODELS.keys())}'
        assert len(set(layers)) == len(layers), 'Repeated layers are not allowed'
        valid_layers: List = VALID_MODELS[model_str]['valid_layers']
        hist: List[int] = [valid_layers.index(layer) for layer in layers]
        assert sorted(hist) == hist, 'Layers must be given in network order'
        super().__init__()
        self.model_str: str = model_str
        self.layers = list(map(str, layers))
        weights = VALID_MODELS[model_str]['weights'].IMAGENET1K_V1
        self.model = VALID_MODELS[model_str]['model'](weights=weights)
        # Freeze the backbone: the extractor only reads intermediate activations.
        for param in self.model.parameters():
            param.requires_grad = False
        # Setting the transformation that ships with the pretrained weights
        self.transform = weights.transforms(antialias=True)

    def forward(self, input: Tensor) -> List[Tensor]:
        features: List[Tensor] = []

        # VGG exposes its conv stack under .features; the ResNets expose
        # their blocks as top-level children.
        module = self.model.features if 'vgg' in self.model_str else self.model
        for name, layer in module.named_children():
            input = layer(input)
            if name in self.layers:
                features.append(input)
            if name == self.layers[-1]:
                break
        return features
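A usage sketch for the new extractor; the batch shape is an assumption, and `transform` is the ImageNet preprocessing bundled with the pretrained weights:

```python
import torch
from lightorch.nn.utils import FeatureExtractor

extractor = FeatureExtractor(layers=[4, 9, 18], model_str='vgg19').eval()
batch = torch.randn(2, 3, 224, 224)  # hypothetical image batch
feats = extractor(extractor.transform(batch))  # one feature map per requested layer
print([f.shape for f in feats])
```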
