diff --git a/robustness/imagenet_models/__init__.py b/robustness/imagenet_models/__init__.py index 5668e1b..255b84d 100644 --- a/robustness/imagenet_models/__init__.py +++ b/robustness/imagenet_models/__init__.py @@ -3,3 +3,4 @@ from .vgg import * from .leaky_resnet import * from .alexnet import * +from .squeezenet import * diff --git a/robustness/imagenet_models/squeezenet.py b/robustness/imagenet_models/squeezenet.py new file mode 100644 index 0000000..a5034f1 --- /dev/null +++ b/robustness/imagenet_models/squeezenet.py @@ -0,0 +1,147 @@ +import torch +import torch.nn as nn +import torch.nn.init as init +from torchvision.models.utils import load_state_dict_from_url +from ..tools.custom_modules import FakeReLUM + + +__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] + +model_urls = { + 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', + 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', +} + + +class Fire(nn.Module): + + def __init__(self, inplanes, squeeze_planes, + expand1x1_planes, expand3x3_planes): + super(Fire, self).__init__() + self.inplanes = inplanes + self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) + self.squeeze_activation = nn.ReLU(inplace=True) + self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, + kernel_size=1) + self.expand1x1_activation = nn.ReLU(inplace=True) + self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, + kernel_size=3, padding=1) + self.expand3x3_activation = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.squeeze_activation(self.squeeze(x)) + return torch.cat([ + self.expand1x1_activation(self.expand1x1(x)), + self.expand3x3_activation(self.expand3x3(x)) + ], 1) + + +class SqueezeNet(nn.Module): + + def __init__(self, version='1_0', num_classes=1000): + super(SqueezeNet, self).__init__() + self.num_classes = num_classes + if version == '1_0': + self.features = nn.Sequential( + nn.Conv2d(3, 96, kernel_size=7, stride=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(96, 16, 64, 64), + Fire(128, 16, 64, 64), + Fire(128, 32, 128, 128), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(256, 32, 128, 128), + Fire(256, 48, 192, 192), + Fire(384, 48, 192, 192), + Fire(384, 64, 256, 256), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(512, 64, 256, 256), + ) + elif version == '1_1': + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, stride=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(64, 16, 64, 64), + Fire(128, 16, 64, 64), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(128, 32, 128, 128), + Fire(256, 32, 128, 128), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(256, 48, 192, 192), + Fire(384, 48, 192, 192), + Fire(384, 64, 256, 256), + Fire(512, 64, 256, 256), + ) + else: + raise ValueError("Unsupported SqueezeNet version {version}:" + "1_0 or 1_1 expected".format(version=version)) + + # Final convolution is initialized differently from the rest + final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) + self.classifier = nn.Sequential( + nn.Dropout(p=0.5), + final_conv, + nn.ReLU(inplace=True), + nn.AdaptiveAvgPool2d((1, 1)) + ) + self.last_relu = nn.ReLU(inplace=True) + self.last_relu_fake = FakeReLUM() + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m is final_conv: + init.normal_(m.weight, mean=0.0, std=0.01) + else: + init.kaiming_uniform_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + + def forward(self, x, with_latent=False, fake_relu=False, no_relu=False): + x = self.features(x) + x_latent = self.classifier[:2](x) + x_relu = self.last_relu(x_latent) if not fake_relu else self.classifier_last_relu_fake(x_latent) + x_out = self.classifier[-1:](x_relu) + x_out = torch.flatten(x_out, 1) + + if with_latent and no_relu: + return x_out, x_latent # potentially will need to flatten x_latent + if with_latent: + return x_out, x_relu # potentially will need to flatten x_relu + return x_out + + +def _squeezenet(version, pretrained, progress, **kwargs): + model = SqueezeNet(version, **kwargs) + if pretrained: + arch = 'squeezenet' + version + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def squeezenet1_0(pretrained=False, progress=True, **kwargs): + r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level + accuracy with 50x fewer parameters and <0.5MB model size" + `_ paper. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _squeezenet('1_0', pretrained, progress, **kwargs) + + + +def squeezenet1_1(pretrained=False, progress=True, **kwargs): + r"""SqueezeNet 1.1 model from the `official SqueezeNet repo + `_. + SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters + than SqueezeNet 1.0, without sacrificing accuracy. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _squeezenet('1_1', pretrained, progress, **kwargs)