-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
redefining fourier and test automation
- Loading branch information
1 parent
db29944
commit d172c33
Showing
7 changed files
with
305 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
from .fourier import * | ||
from .transformer.attention import * | ||
from .transformer import * | ||
from .partial import * | ||
from .normalization import * | ||
from .criterions import * | ||
from .complex import * | ||
from .dnn import * | ||
from .kan import * | ||
from .monte_carlo import * | ||
from .utils import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from torch import nn, Tensor | ||
from typing import List, Sequence | ||
from torchvision.models import ( | ||
VGG19_Weights, vgg19, | ||
VGG16_Weights, vgg16, | ||
ResNet50_Weights, resnet50, | ||
resnet101, ResNet101_Weights | ||
) | ||
|
||
|
||
VALID_MODELS = { | ||
'vgg19': { | ||
'model': vgg19, | ||
'weights': VGG19_Weights, | ||
'valid_layers': [ | ||
i for i in range(37) | ||
] | ||
}, | ||
'vgg16': { | ||
'model': vgg16, | ||
'weights': VGG16_Weights, | ||
'valid_layers': [ | ||
i for i in range(31) | ||
] | ||
}, | ||
'resnet50': { | ||
'model': resnet50, | ||
'weights': ResNet50_Weights, | ||
'valid_layers': [ | ||
'conv1', | ||
'bn1', | ||
'relu', | ||
'maxpool', | ||
'layer1', | ||
'layer2', | ||
'layer3', | ||
'layer4', | ||
'avgpool', | ||
'fc', | ||
] | ||
}, | ||
'resnet101': { | ||
'model': resnet101, | ||
'weights': ResNet101_Weights, | ||
'valid_layers': [ | ||
'conv1', | ||
'bn1', | ||
'relu', | ||
'maxpool', | ||
'layer1', | ||
'layer2', | ||
'layer3', | ||
'layer4', | ||
'avgpool', | ||
'fc', | ||
] | ||
}, | ||
} | ||
|
||
|
||
class FeatureExtractor(nn.Module): | ||
def __init__(self, layers: Sequence[int] = [4,9,18], model_str: str = 'vgg19') -> None: | ||
assert(model_str in VALID_MODELS), f'Model not in {VALID_MODELS.keys()}' | ||
assert (list(set(layers)) == layers), 'Not valid repeated inputs' | ||
hist: List = [] | ||
for layer in layers: | ||
valid_models: List[str] = VALID_MODELS[model_str]['valid_layers'] | ||
num = valid_models.index(layer) | ||
hist.append(num) | ||
assert (sorted(hist) == hist), 'Not ordered inputs' | ||
super().__init__() | ||
self.model_str: str = model_str | ||
self.layers = list(map(str, layers)) | ||
self.model = VALID_MODELS[model_str](weights = VALID_MODELS[model_str][1].IMAGENET1K_V1) | ||
for param in self.model.parameters(): | ||
param.requires_grad = False | ||
# Setting the transformation | ||
self.transform = VALID_MODELS[model_str][1].IMAGENET1K_V1.transforms(antialias=True) | ||
|
||
def forward(self, input: Tensor) -> List[Tensor]: | ||
features = [] | ||
|
||
if 'vgg' in self.model_str: | ||
for name, layer in self.model.features.named_children(): | ||
input = layer(input) | ||
if name in self.layers: | ||
features.append(input) | ||
if name == self.layers[-1]: | ||
return features | ||
else: | ||
for name, layer in self.model.named_children(): | ||
input = layer(input) | ||
if name in self.layers: | ||
features.append(input) | ||
if name == self.layers[-1]: | ||
return features |
Oops, something went wrong.