Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jorgedavyd committed Jun 10, 2024
1 parent 8b14d60 commit 797c017
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 12 deletions.
2 changes: 1 addition & 1 deletion lightorch/nn/dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ def forward(self, input: Tensor) -> Tensor:
return self.dnn(input)


__all__ = ["DeepNeuralNetwork"]
__all__ = ["DeepNeuralNetwork"]
108 changes: 101 additions & 7 deletions lightorch/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)


VALID_MODELS = {
VALID_MODELS_2D = {
"vgg19": {
"model": vgg19,
"weights": VGG19_Weights,
Expand Down Expand Up @@ -58,28 +58,28 @@
}


class FeatureExtractor(nn.Module):
class FeatureExtractor2D(nn.Module):
def __init__(
self, layers: Sequence[int] = [4, 9, 18], model_str: str = "vgg19"
) -> None:
assert model_str in VALID_MODELS, f"Model not in {VALID_MODELS.keys()}"
assert model_str in VALID_MODELS_2D, f"Model not in {VALID_MODELS_2D.keys()}"
assert list(set(layers)) == layers, "Not valid repeated inputs"
hist: List = []
for layer in layers:
valid_models: List[str] = VALID_MODELS[model_str]["valid_layers"]
valid_models: List[str] = VALID_MODELS_2D[model_str]["valid_layers"]
num = valid_models.index(layer)
hist.append(num)
assert sorted(hist) == hist, "Not ordered inputs"
super().__init__()
self.model_str: str = model_str
self.layers = list(map(str, layers))
self.model = VALID_MODELS[model_str]["model"](
weights=VALID_MODELS[model_str]["weights"].IMAGENET1K_V1
self.model = VALID_MODELS_2D[model_str]["model"](
weights=VALID_MODELS_2D[model_str]["weights"].IMAGENET1K_V1
)
for param in self.model.parameters():
param.requires_grad = False
# Setting the transformation
self.transform = VALID_MODELS[model_str]["weights"].IMAGENET1K_V1.transforms(
self.transform = VALID_MODELS_2D[model_str]["weights"].IMAGENET1K_V1.transforms(
antialias=True
)

Expand All @@ -100,3 +100,97 @@ def forward(self, input: Tensor) -> List[Tensor]:
features.append(input)
if name == self.layers[-1]:
return features

# Backbone registry for FeatureExtractor3D: maps a model key to its
# constructor, pretrained-weight enum, and the layer identifiers that may be
# requested for feature extraction.
# NOTE(review): this table is identical to VALID_MODELS_2D and lists
# torchvision *2-D* classification backbones — confirm that genuinely 3-D
# (e.g. video) models were not intended under this name.
VALID_MODELS_3D = {
    "vgg19": {
        "model": vgg19,
        "weights": VGG19_Weights,
        # vgg feature layers are addressed by sequential index (0..36 here).
        "valid_layers": list(range(37)),
    },
    "vgg16": {
        "model": vgg16,
        "weights": VGG16_Weights,
        "valid_layers": list(range(31)),
    },
    "resnet50": {
        "model": resnet50,
        "weights": ResNet50_Weights,
        # resnet layers are addressed by attribute name, in forward order.
        "valid_layers": [
            "conv1",
            "bn1",
            "relu",
            "maxpool",
            "layer1",
            "layer2",
            "layer3",
            "layer4",
            "avgpool",
            "fc",
        ],
    },
    "resnet101": {
        "model": resnet101,
        "weights": ResNet101_Weights,
        "valid_layers": [
            "conv1",
            "bn1",
            "relu",
            "maxpool",
            "layer1",
            "layer2",
            "layer3",
            "layer4",
            "avgpool",
            "fc",
        ],
    },
}


class FeatureExtractor3D(nn.Module):
def __init__(
self, layers: Sequence[int] = [4, 9, 18], model_str: str = "vgg19"
) -> None:
assert model_str in VALID_MODELS_3D, f"Model not in {VALID_MODELS_3D.keys()}"
assert list(set(layers)) == layers, "Not valid repeated inputs"
hist: List = []
for layer in layers:
valid_models: List[str] = VALID_MODELS_3D[model_str]["valid_layers"]
num = valid_models.index(layer)
hist.append(num)
assert sorted(hist) == hist, "Not ordered inputs"
super().__init__()
self.model_str: str = model_str
self.layers = list(map(str, layers))
self.model = VALID_MODELS_3D[model_str]["model"](
weights=VALID_MODELS_3D[model_str]["weights"].IMAGENET1K_V1
)
for param in self.model.parameters():
param.requires_grad = False
# Setting the transformation
self.transform = VALID_MODELS_3D[model_str]["weights"].IMAGENET1K_V1.transforms(
antialias=True
)

def forward(self, input: Tensor) -> List[Tensor]:
features = []

if "vgg" in self.model_str:
for name, layer in self.model.features.named_children():
input = layer(input)
if name in self.layers:
features.append(input)
if name == self.layers[-1]:
return features
else:
for name, layer in self.model.named_children():
input = layer(input)
if name in self.layers:
features.append(input)
if name == self.layers[-1]:
return features

# Public API of this module.
__all__ = ["FeatureExtractor2D", "FeatureExtractor3D"]
41 changes: 37 additions & 4 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_tv() -> None:
def test_style() -> None:
input: Tensor = torch.randn(1, 3, 256, 256)
target: Tensor = torch.randn(1, 3, 256, 256)
feature_extractor: nn.Module = FeatureExtractor([8, 12], "vgg16")
feature_extractor: nn.Module = FeatureExtractor2D([8, 12], "vgg16")
loss = StyleLoss(feature_extractor, input, randint)
result = loss(input=input, target=target, feature_extractor=True)
assert result is not None, "StyleLoss failed"
Expand All @@ -41,7 +41,7 @@ def test_style() -> None:
def test_perc() -> None:
input: Tensor = torch.randn(1, 3, 256, 256)
target: Tensor = torch.randn(1, 3, 256, 256)
feature: nn.Module = FeatureExtractor([8, 12], "vgg16")
feature: nn.Module = FeatureExtractor2D([8, 12], "vgg16")
loss = PerceptualLoss(feature, input, randint)
result = loss(input=input, target=target, feature_extractor=False)
assert result is not None, "PerceptualLoss failed"
Expand Down Expand Up @@ -377,7 +377,7 @@ def test_ffn(model_class, params) -> None:
out_features = params['out_features']
assert output.shape == (32, out_features)

def test_pos_embed() -> None:
def test_pos() -> None:
dropout = 0.1
batch_size = 32
seq_length = 10
Expand All @@ -398,4 +398,37 @@ def test_pos_embed() -> None:
output = abs_pos_enc(input_tensor)
assert output.shape == input_tensor.shape
output = dn_pos_enc(input_tensor)
assert output.shape == input_tensor.shape
assert output.shape == input_tensor.shape

def test_patch_embedding_3dcnn() -> None:
    """Smoke-test the 3D-CNN and 2D-CNN patch embeddings on a video batch.

    Only output shapes are checked, matching the other shape tests in this
    file.
    """
    batch_size = 2
    frames = 8
    channels = 3
    height = 32
    width = 32
    h_div = 4
    w_div = 4
    d_model = 64
    architecture = (channels,)
    hidden_activations = (nn.ReLU(),)
    dropout = 0.1

    # (batch, frames, channels, H, W) video-like input.
    input_tensor = torch.randn(batch_size, frames, channels, height, width)

    feature_extractor = FeatureExtractor3D()  # Define
    pe = AbsoluteSinusoidalPositionalEncoding()

    # NOTE(review): class name carries a triple "d" (PatchEmbeddding3DCNN) —
    # confirm it matches the actual definition rather than a typo.
    patch_embed = PatchEmbeddding3DCNN(
        h_div=h_div,
        w_div=w_div,
        pe=pe,
        feature_extractor=feature_extractor,
        X=input_tensor,
    )
    output = patch_embed(input_tensor)
    # One token per spatial patch, projected to d_model.
    assert output.shape == (batch_size, h_div * w_div, d_model)

    feature_extractor = FeatureExtractor2D()
    patch_embed = PatchEmbedding2DCNN(
        d_model=d_model,
        pe=pe,
        feature_extractor=feature_extractor,
        architecture=architecture,
        hidden_activations=hidden_activations,
        dropout=dropout,
    )
    output = patch_embed(input_tensor)
    # One token per frame.
    assert output.shape == (batch_size, frames, d_model)

0 comments on commit 797c017

Please sign in to comment.