forked from hankook/AugSelf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
resnets.py
93 lines (76 loc) · 2.89 KB
/
resnets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import Any, Type, Union, List, Dict
import torch
from torch import Tensor, nn
from torch.hub import load_state_dict_from_url
from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls
from models import reset_parameters
class ResnetOutBlocks(ResNet):
    """torchvision ``ResNet`` whose forward pass also exposes intermediate activations.

    In torchvision, ``ResNet.forward`` delegates to ``_forward_impl``. Overriding
    that method therefore makes calling the model return a dict of tensors
    instead of only the classifier output.
    """

    def _forward_impl(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the ResNet forward pass and collect every stage's output.

        Returns a dict with these keys, in this order:
            ``input``        -- the unmodified input ``x``
            ``conv1``        -- stem output: conv1 -> bn1 -> relu -> maxpool
            ``l1``..``l4``   -- outputs of ``layer1``..``layer4``
            ``backbone_out`` -- ``avgpool`` output flattened from dim 1
            ``out``          -- ``self.fc`` applied to ``backbone_out``
        """
        outputs: Dict[str, Tensor] = {"input": x}

        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        outputs["conv1"] = stem

        # Feed each residual stage the previous stage's result, recording each one.
        feat = stem
        stages = (
            ("l1", self.layer1),
            ("l2", self.layer2),
            ("l3", self.layer3),
            ("l4", self.layer4),
        )
        for key, stage in stages:
            feat = stage(feat)
            outputs[key] = feat

        pooled = torch.flatten(self.avgpool(feat), 1)
        outputs["backbone_out"] = pooled
        outputs["out"] = self.fc(pooled)
        return outputs
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> ResnetOutBlocks:
    """Construct a ``ResnetOutBlocks`` network and optionally load pretrained weights.

    Args:
        arch: key into ``model_urls`` locating the pretrained checkpoint.
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four residual stages.
        pretrained: if True, download the ``arch`` checkpoint and load it.
        progress: forwarded to ``load_state_dict_from_url`` (download progress bar).
        **kwargs: forwarded unchanged to the ``ResnetOutBlocks`` constructor.

    NOTE(review): ``model_urls`` is imported from ``torchvision.models.resnet``;
    newer torchvision releases (0.15+) no longer export it, so confirm the
    pinned torchvision version.
    """
    net = ResnetOutBlocks(block, layers, **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResnetOutBlocks:
    r"""Build a ResNet-18 (``BasicBlock``, stage depths 2-2-2-2) that returns intermediate outputs.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, load weights pre-trained on ImageNet.
        progress (bool): If True, show a download progress bar on stderr.
        **kwargs: Extra keyword arguments forwarded to the network constructor.
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet("resnet18", BasicBlock, stage_depths, pretrained, progress, **kwargs)
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResnetOutBlocks:
    r"""Build a ResNet-50 (``Bottleneck``, stage depths 3-4-6-3) that returns intermediate outputs.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, load weights pre-trained on ImageNet.
        progress (bool): If True, show a download progress bar on stderr.
        **kwargs: Extra keyword arguments forwarded to the network constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet("resnet50", Bottleneck, stage_depths, pretrained, progress, **kwargs)
def load_backbone_out_blocks(args):
    """Build a ``ResnetOutBlocks`` backbone for ``args.model`` with its ``fc`` head removed.

    Supported names are ``resnet18`` and ``resnet50``. Either may carry a
    ``cifar_`` prefix (e.g. ``cifar_resnet18``) to get a small-image stem: a
    3x3 stride-1 ``conv1`` and no max-pool.

    Side effects:
        Sets ``args.num_backbone_features`` to ``backbone.fc.weight.shape[1]``
        (the input width of the original ``fc`` layer). ``fc`` is then replaced
        with ``nn.Identity``.

    Raises:
        NotImplementedError: if ``args.model`` does not name a supported architecture.
    """
    name = args.model
    prefix = 'cifar_'
    is_cifar = name.startswith(prefix)
    # Strip the optional prefix before matching. Otherwise 'cifar_*' names can
    # never match an architecture below, and the CIFAR stem is unreachable.
    arch = name[len(prefix):] if is_cifar else name

    if arch == "resnet18":
        backbone = resnet18(zero_init_residual=True)
    elif arch == "resnet50":
        backbone = resnet50(zero_init_residual=True)
    else:
        raise NotImplementedError(name)

    if is_cifar:
        # Small inputs: keep spatial resolution in the stem.
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()

    args.num_backbone_features = backbone.fc.weight.shape[1]
    backbone.fc = nn.Identity()
    # Project helper applied to the whole backbone (presumably re-initializes
    # parameters) -- see models.reset_parameters.
    reset_parameters(backbone)
    return backbone