Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Jorgedavyd committed May 31, 2024
2 parents f9773bf + ba72e5a commit 9b934a3
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 74 deletions.
2 changes: 1 addition & 1 deletion lightorch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
from .dnn import *
from .kan import *
from .monte_carlo import *
from .utils import *
from .utils import *
48 changes: 37 additions & 11 deletions lightorch/nn/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Tuple, Sequence
from itertools import chain


class _FourierConvNd(nn.Module):
def __init__(
self,
Expand All @@ -25,9 +26,9 @@ def __init__(
) -> None:
self.n = n
if isinstance(kernel_size, tuple):
assert (n == n), f'Not valid kernel size for {n}-convolution'
assert n == n, f"Not valid kernel size for {n}-convolution"
else:
kernel_size = (kernel_size, )*n
kernel_size = (kernel_size,) * n

super().__init__()
self.factory_kwargs = {"device": device, "dtype": dtype}
Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
self.bias = None

self._init_parameters()

def get_padding(self, padding: Tuple[int, ...] | int) -> Sequence[int]:
if isinstance(padding, tuple):
assert(len(padding) == self.n), f'Not valid padding scheme for {self.n}-convolution'
Expand All @@ -86,8 +87,33 @@ def _init_parameters(self) -> None:


class _FourierDeconvNd(_FourierConvNd):
def __init__(self, n: int, in_channels: int, out_channels: int, kernel_size: Tuple[int] | int, padding: Tuple[int], bias: bool = True, eps: float = 0.00001, pre_fft: bool = True, post_ifft: bool = False, device=None, dtype=None) -> None:
super().__init__(n, in_channels, out_channels, kernel_size, padding, bias, eps, pre_fft, post_ifft, device, dtype)
def __init__(
self,
n: int,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int] | int,
padding: Tuple[int],
bias: bool = True,
eps: float = 0.00001,
pre_fft: bool = True,
post_ifft: bool = False,
device=None,
dtype=None,
) -> None:
super().__init__(
n,
in_channels,
out_channels,
kernel_size,
padding,
bias,
eps,
pre_fft,
post_ifft,
device,
dtype,
)


class FourierConv1d(_FourierConvNd):
Expand All @@ -102,7 +128,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, mode="constant", value=0),
self.one,
self.weight,
self.bias
self.bias,
)
else:
out = F.fourierconv1d(input, self.one, self.weight, self.bias)
Expand All @@ -122,7 +148,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
self.bias
self.bias,
)
else:
out = F.fourierconv2d(input, self.one, self.weight, self.bias)
Expand All @@ -145,7 +171,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
self.bias
self.bias,
)
else:
out = F.fourierconv3d(input, self.one, self.weight, self.bias)
Expand All @@ -167,7 +193,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
self.bias
self.bias,
)
else:
out = F.fourierdeconv1d(input, self.one, self.weight, self.bias)
Expand All @@ -188,7 +214,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
self.bias
self.bias,
)
else:
out = F.fourierdeconv2d(input, self.one, self.weight, self.bias)
Expand All @@ -209,7 +235,7 @@ def forward(self, input: Tensor) -> Tensor:
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
self.bias
self.bias,
)
else:
out = F.fourierdeconv3d(input, self.one, self.weight, self.bias)
Expand Down
1 change: 1 addition & 0 deletions lightorch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def _fourierconvNd(

return x


def _fourierdeconvNd(
n: int, x: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5
) -> Tensor:
Expand Down
112 changes: 57 additions & 55 deletions lightorch/nn/utils.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,75 @@
from torch import nn, Tensor
from typing import List, Sequence
from torchvision.models import (
VGG19_Weights, vgg19,
VGG16_Weights, vgg16,
ResNet50_Weights, resnet50,
resnet101, ResNet101_Weights
VGG19_Weights,
vgg19,
VGG16_Weights,
vgg16,
ResNet50_Weights,
resnet50,
resnet101,
ResNet101_Weights,
)


VALID_MODELS = {
'vgg19': {
'model': vgg19,
'weights': VGG19_Weights,
'valid_layers': [
i for i in range(37)
]
"vgg19": {
"model": vgg19,
"weights": VGG19_Weights,
"valid_layers": [i for i in range(37)],
},
'vgg16': {
'model': vgg16,
'weights': VGG16_Weights,
'valid_layers': [
i for i in range(31)
]
"vgg16": {
"model": vgg16,
"weights": VGG16_Weights,
"valid_layers": [i for i in range(31)],
},
"resnet50": {
"model": resnet50,
"weights": ResNet50_Weights,
"valid_layers": [
"conv1",
"bn1",
"relu",
"maxpool",
"layer1",
"layer2",
"layer3",
"layer4",
"avgpool",
"fc",
],
},
"resnet101": {
"model": resnet101,
"weights": ResNet101_Weights,
"valid_layers": [
"conv1",
"bn1",
"relu",
"maxpool",
"layer1",
"layer2",
"layer3",
"layer4",
"avgpool",
"fc",
],
},
'resnet50': {
'model': resnet50,
'weights': ResNet50_Weights,
'valid_layers': [
'conv1',
'bn1',
'relu',
'maxpool',
'layer1',
'layer2',
'layer3',
'layer4',
'avgpool',
'fc',
]
},
'resnet101': {
'model': resnet101,
'weights': ResNet101_Weights,
'valid_layers': [
'conv1',
'bn1',
'relu',
'maxpool',
'layer1',
'layer2',
'layer3',
'layer4',
'avgpool',
'fc',
]
},
}


class FeatureExtractor(nn.Module):
def __init__(self, layers: Sequence[int] = [4,9,18], model_str: str = 'vgg19') -> None:
assert(model_str in VALID_MODELS), f'Model not in {VALID_MODELS.keys()}'
assert (list(set(layers)) == layers), 'Not valid repeated inputs'
def __init__(
self, layers: Sequence[int] = [4, 9, 18], model_str: str = "vgg19"
) -> None:
assert model_str in VALID_MODELS, f"Model not in {VALID_MODELS.keys()}"
assert list(set(layers)) == layers, "Not valid repeated inputs"
hist: List = []
for layer in layers:
valid_models: List[str] = VALID_MODELS[model_str]['valid_layers']
valid_models: List[str] = VALID_MODELS[model_str]["valid_layers"]
num = valid_models.index(layer)
hist.append(num)
assert (sorted(hist) == hist), 'Not ordered inputs'
assert sorted(hist) == hist, "Not ordered inputs"
super().__init__()
self.model_str: str = model_str
self.layers = list(map(str, layers))
Expand All @@ -79,8 +81,8 @@ def __init__(self, layers: Sequence[int] = [4,9,18], model_str: str = 'vgg19') -

def forward(self, input: Tensor) -> List[Tensor]:
features = []
if 'vgg' in self.model_str:

if "vgg" in self.model_str:
for name, layer in self.model.features.named_children():
input = layer(input)
if name in self.layers:
Expand Down
20 changes: 13 additions & 7 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_complex() -> None:
sample_input = torch.randn(32, 10) + 1j * torch.randn(32, 10)
layer = Complex(nn.Linear(10, 20))
result = layer(sample_input)
assert result is not None, 'Complex failed'
assert result is not None, "Complex failed"


def test_tv() -> None:
Expand All @@ -23,7 +23,8 @@ def test_tv() -> None:
result = loss(input=input)
assert result is not None, "TV loss failed"

#Integrated

# Integrated
def test_style() -> None:
input: Tensor = torch.randn(1, 3, 256, 256)
target: Tensor = torch.randn(1, 3, 256, 256)
Expand All @@ -33,7 +34,7 @@ def test_style() -> None:
assert result is not None, "StyleLoss failed"


#Integrated
# Integrated
def test_perc() -> None:
input: Tensor = torch.randn(1, 3, 256, 256)
target: Tensor = torch.randn(1, 3, 256, 256)
Expand All @@ -57,6 +58,7 @@ def test_entropy_loss() -> None:
result = loss(input=input, target=target)
assert result is not None, "CrossEntropy failed"


def test_psnr() -> None:
input: Tensor = torch.randn(1, 3, 256, 256)
target: Tensor = torch.randn(1, 3, 256, 256)
Expand Down Expand Up @@ -171,7 +173,9 @@ def test_fourier1d() -> None:
assert output.shape == (32, 3, 10), "FourierDeconv1d failed"

def test_fourier3d() -> None:
sample_input: Tensor = torch.randn(32, 3, 5, 256, 256) # batch size, channels, frames, height, width
sample_input: Tensor = torch.randn(
32, 3, 5, 256, 256
) # batch size, channels, frames, height, width
model = nn.Sequential(
FourierConv3d(
3,
Expand Down Expand Up @@ -217,11 +221,14 @@ def test_partial() -> None:


def test_normalization() -> None:
sample_input: Tensor = torch.randn(32, 20, 10) # batch size, sequence_length, input_size
sample_input: Tensor = torch.randn(
32, 20, 10
) # batch size, sequence_length, input_size
norm = RootMeanSquaredNormalization(dim=10)
output = norm(sample_input)
assert output.shape == (32, 20, 10), "RootMeanSquaredNormalization failed"


# Integrated
def test_monte_carlo() -> None:
sample_input: Tensor = torch.randn(32, 10) # batch size, input_size
Expand All @@ -232,13 +239,12 @@ def test_monte_carlo() -> None:
activations=(nn.ReLU, nn.ReLU, nn.Sigmoid)
),
dropout=0.5,
n_sampling=50
n_sampling=50,
)
output = model(sample_input)
assert output.shape == (32, 1), "MonteCarloFC failed"



def test_kan() -> None:
# Placeholder for future implementation
raise NotImplementedError("KAN test not implemented")
2 changes: 2 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch.utils.data import Dataset, DataLoader
import random


def create_inputs(*size) -> Tensor:
return torch.randn(*size)

Expand Down Expand Up @@ -52,6 +53,7 @@ def val_dataloader(self) -> DataLoader:
pin_memory=self.pin_memory,
)


def create_mask():
# Create a rectangle
n = random.randint(1, 31) # Random integer for n
Expand Down

0 comments on commit 9b934a3

Please sign in to comment.