Merge pull request #62 from kristian-georgiev/master
add squeezenet models
Showing 2 changed files with 148 additions and 0 deletions.
@@ -3,3 +3,4 @@
 from .vgg import *
 from .leaky_resnet import *
 from .alexnet import *
+from .squeezenet import *
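The only change to the existing file is the added wildcard import, which re-exports the new SqueezeNet constructors next to the other architectures. Assuming this is the imagenet_models package of the robustness library (the file path is not shown in this view, but the FakeReLUM import in the new file points that way), the models would then be reachable from the package namespace, roughly:

# Hypothetical import path, inferred from the surrounding library; not shown in the diff.
from robustness import imagenet_models

arch = imagenet_models.squeezenet1_1(num_classes=1000)  # randomly initialized SqueezeNet 1.1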
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision.models.utils import load_state_dict_from_url
from ..tools.custom_modules import FakeReLUM


__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']


model_urls = {
    'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
    'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}


class Fire(nn.Module):

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)

class SqueezeNet(nn.Module):

    def __init__(self, version='1_0', num_classes=1000):
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        if version == '1_0':
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == '1_1':
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            raise ValueError("Unsupported SqueezeNet version {version}: "
                             "1_0 or 1_1 expected".format(version=version))

        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.last_relu = nn.ReLU(inplace=True)
        self.last_relu_fake = FakeReLUM()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x, with_latent=False, fake_relu=False, no_relu=False):
        x = self.features(x)
        # x_latent: pre-ReLU output of Dropout + final 1x1 convolution
        x_latent = self.classifier[:2](x)
        x_relu = self.last_relu(x_latent) if not fake_relu else self.last_relu_fake(x_latent)
        x_out = self.classifier[-1:](x_relu)
        x_out = torch.flatten(x_out, 1)

        if with_latent and no_relu:
            return x_out, x_latent  # potentially will need to flatten x_latent
        if with_latent:
            return x_out, x_relu  # potentially will need to flatten x_relu
        return x_out

def _squeezenet(version, pretrained, progress, **kwargs):
    model = SqueezeNet(version, **kwargs)
    if pretrained:
        arch = 'squeezenet' + version
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def squeezenet1_0(pretrained=False, progress=True, **kwargs):
    r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
    accuracy with 50x fewer parameters and <0.5MB model size"
    <https://arxiv.org/abs/1602.07360>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _squeezenet('1_0', pretrained, progress, **kwargs)


def squeezenet1_1(pretrained=False, progress=True, **kwargs):
    r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
    than SqueezeNet 1.0, without sacrificing accuracy.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _squeezenet('1_1', pretrained, progress, **kwargs)
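A minimal usage sketch of the new interface, for orientation only: it assumes this module is importable as written and that the torchvision checkpoint URLs above remain valid. The with_latent flag returns the post-ReLU, pre-pooling activation map alongside the logits, and fake_relu routes that activation through the FakeReLUM module instead of the plain ReLU.

import torch

# Hypothetical usage; the constructor names come from this file, everything else is assumed.
model = squeezenet1_0(pretrained=True)    # downloads torchvision's SqueezeNet 1.0 weights
model.eval()

x = torch.randn(2, 3, 224, 224)           # dummy batch of two 224x224 RGB images
with torch.no_grad():
    logits = model(x)                     # standard forward pass -> (2, 1000) logits
    logits, latent = model(x, with_latent=True)   # also return the pre-pooling activation map
print(logits.shape, latent.shape)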