# adapted from https://github.com/fastai/fastai/blob/master/fastai/vision/models/xresnet.py
# added a simple self-attention layer and conv1d
# ideally conv1d would be available from fast.ai's layers.py, or the SimpleSelfAttention layer would live in layers.py
# added `sa` option to the XResNet class
# added `sa` to the ResBlock class and to the xresnet function
# imported fastai.torch_core for spectral_norm and tensor

from fastai.torch_core import *
import torch.nn as nn
import torch, math, sys
import torch.utils.model_zoo as model_zoo
from functools import partial

__all__ = ['XResNet', 'xresnet18', 'xresnet34', 'xresnet50', 'xresnet101', 'xresnet152']

# or: ELU+init (a=0.54; gain=1.55)
act_fn = nn.ReLU(inplace=True)
# Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
def conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    "Create and initialize a `nn.Conv1d` layer with spectral normalization."
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias: conv.bias.data.zero_()
    return spectral_norm(conv)
# Adapted from the SelfAttention layer at https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
# Inspired by https://arxiv.org/pdf/1805.08318.pdf
class SimpleSelfAttention(nn.Module):
    def __init__(self, n_in:int, ks=1, sym=False):#, n_out:int):
        super().__init__()
        self.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False)
        self.gamma = nn.Parameter(tensor([0.]))
        self.sym = sym
        self.n_in = n_in

    def forward(self, x):
        if self.sym:
            # symmetry hack by https://github.com/mgrankin
            # note: `self.conv` is wrapped in spectral_norm, so `weight` is a plain
            # attribute here, not an nn.Parameter
            c = self.conv.weight.view(self.n_in, self.n_in)
            c = (c + c.t()) / 2
            self.conv.weight = c.view(self.n_in, self.n_in, 1)
        size = x.size()
        x = x.view(*size[:2], -1)   # (B,C,N): flatten the spatial dimensions
        # changed the order of multiplication to avoid O(N^2) complexity
        # (x*xT)*(W*x) instead of (x*(xT*(W*x)))
        convx = self.conv(x)                                 # (C,C) * (C,N) = (C,N) => O(NC^2)
        xxT = torch.bmm(x, x.permute(0,2,1).contiguous())    # (C,N) * (N,C) = (C,C) => O(NC^2)
        o = torch.bmm(xxT, convx)                            # (C,C) * (C,N) = (C,N) => O(NC^2)
        o = self.gamma * o + x
        return o.view(*size).contiguous()
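
# Minimal usage sketch (added for illustration; not part of the original module).
# SimpleSelfAttention is shape-preserving, so it can be dropped into any
# convolutional block whose channel count matches `n_in`:
#
#     sa = SimpleSelfAttention(64, ks=1, sym=False)
#     feats = torch.randn(2, 64, 8, 8)   # (B, C, H, W)
#     out = sa(feats)                    # -> torch.Size([2, 64, 8, 8])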
class Flatten(nn.Module):
    def forward(self, x): return x.view(x.size(0), -1)

def init_cnn(m):
    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
    if isinstance(m, (nn.Conv2d,nn.Linear)): nn.init.kaiming_normal_(m.weight)
    for l in m.children(): init_cnn(l)

def conv(ni, nf, ks=3, stride=1, bias=False):
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias)

def noop(x): return x

def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):
    bn = nn.BatchNorm2d(nf)
    nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
    layers = [conv(ni, nf, ks, stride=stride), bn]
    if act: layers.append(act_fn)
    return nn.Sequential(*layers)
class ResBlock(nn.Module):
    def __init__(self, expansion, ni, nh, stride=1, sa=False, sym=False):
        super().__init__()
        nf,ni = nh*expansion,ni*expansion
        layers = [conv_layer(ni, nh, 3, stride=stride),
                  conv_layer(nh, nf, 3, zero_bn=True, act=False)
        ] if expansion == 1 else [
                  conv_layer(ni, nh, 1),
                  conv_layer(nh, nh, 3, stride=stride),
                  conv_layer(nh, nf, 1, zero_bn=True, act=False)
        ]
        self.sa = SimpleSelfAttention(nf, ks=1, sym=sym) if sa else noop
        self.convs = nn.Sequential(*layers)
        # TODO: check whether act=True works better
        self.idconv = noop if ni==nf else conv_layer(ni, nf, 1, act=False)
        self.pool = noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)

    def forward(self, x): return act_fn(self.sa(self.convs(x)) + self.idconv(self.pool(x)))
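
# Dataflow of a ResBlock (added for illustration): the main path runs the stacked
# conv_layers and, when enabled, SimpleSelfAttention; the identity path is pooled
# (for stride 2) and projected by a 1x1 conv when the channel count changes:
#
#     out = act_fn( sa(convs(x)) + idconv(pool(x)) )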
def filt_sz(recep): return min(64, 2**math.floor(math.log2(recep*0.75)))
class XResNet(nn.Sequential):
    def __init__(self, expansion, layers, c_in=3, c_out=1000, sa=False, sym=False):
        stem = []
        sizes = [c_in,32,32,64]
        for i in range(3):
            stem.append(conv_layer(sizes[i], sizes[i+1], stride=2 if i==0 else 1))
            #nf = filt_sz(c_in*9)
            #stem.append(conv_layer(c_in, nf, stride=2 if i==1 else 1))
            #c_in = nf

        block_szs = [64//expansion,64,128,256,512]
        # self-attention is only enabled in the block group at index len(layers)-4
        # (the first group for the standard 4-group nets)
        blocks = [self._make_layer(expansion, block_szs[i], block_szs[i+1], l, 1 if i==0 else 2,
                                   sa=sa if i==len(layers)-4 else False, sym=sym)
                  for i,l in enumerate(layers)]
        super().__init__(
            *stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *blocks,
            nn.AdaptiveAvgPool2d(1), Flatten(),
            nn.Linear(block_szs[-1]*expansion, c_out),
        )
        init_cnn(self)

    def _make_layer(self, expansion, ni, nf, blocks, stride, sa=False, sym=False):
        # attach self-attention only to the last ResBlock in the group
        return nn.Sequential(
            *[ResBlock(expansion, ni if i==0 else nf, nf, stride if i==0 else 1,
                       sa if i==blocks-1 else False, sym)
              for i in range(blocks)])
def xresnet(expansion, n_layers, name, pretrained=False, **kwargs):
    model = XResNet(expansion, n_layers, **kwargs)
    # note: `model_urls` is not defined in this module, so `pretrained=True`
    # requires supplying a name-to-checkpoint-URL mapping
    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls[name]))
    return model
me = sys.modules[__name__]
for n,e,l in [
    [ 18 , 1, [2,2,2 ,2] ],
    [ 34 , 1, [3,4,6 ,3] ],
    [ 50 , 4, [3,4,6 ,3] ],
    [ 101, 4, [3,4,23,3] ],
    [ 152, 4, [3,8,36,3] ],
]:
    name = f'xresnet{n}'
    setattr(me, name, partial(xresnet, expansion=e, n_layers=l, name=name))
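
# Minimal smoke test (added for illustration; not part of the original module).
# Builds the self-attention variant of xresnet50 and runs a single forward pass.
if __name__ == '__main__':
    model = xresnet50(sa=True, sym=False, c_out=10)
    x = torch.randn(2, 3, 128, 128)
    out = model(x)
    print(out.shape)   # expected: torch.Size([2, 10])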