-
-
Notifications
You must be signed in to change notification settings - Fork 50
/
Xception.py
159 lines (127 loc) · 5.25 KB
/
Xception.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
Xception
The network uses a modified version of Depthwise Seperable Convolution. It combines
ideas from MobileNetV1 like depthwise seperable conv and from InceptionV3, the order
of the layers like conv1x1 and then spatial kernels.
In modified Depthwise Seperable Convolution network, the order of operation is changed
by keeping Conv1x1 and then the spatial convolutional kernel. And the other difference
is the absence of Non-Linear activation function. And with inclusion of residual
connections impacts the performs of Xception widely.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SeparableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    The first convolution uses ``groups=input_channel`` so every input channel
    is filtered on its own; the second is a 1x1 convolution that mixes the
    channels and maps them to ``output_channel``. No normalisation or
    activation is applied between the two.

    Args:
        input_channel: channel count expected on the input.
        output_channel: channel count produced by the pointwise convolution.
        kernel_size, stride, padding, dilation: forwarded to the depthwise
            convolution only; the pointwise convolution is always 1x1,
            stride 1, no padding.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, input_channel, output_channel, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()
        # Per-channel spatial filtering: one group per input channel.
        depthwise = nn.Conv2d(
            in_channels=input_channel,
            out_channels=input_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=input_channel,
            bias=bias,
        )
        # Cross-channel mixing with a plain 1x1 convolution.
        pointwise = nn.Conv2d(
            in_channels=input_channel,
            out_channels=output_channel,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )
        # Kept under the name ``dwc`` (indices 0 and 1) so parameter names
        # in the state dict stay the same.
        self.dwc = nn.Sequential(depthwise, pointwise)

    def forward(self, X):
        """Run ``X`` through the depthwise and then the pointwise convolution."""
        return self.dwc(X)
class Block(nn.Module):
    """Residual block made of ``reps`` (ReLU -> SeparableConv 3x3 -> BatchNorm) units.

    The main path is stored in ``self.rep``. The shortcut is the identity when
    the channel count and resolution are unchanged; otherwise it is a strided
    1x1 convolution followed by batch normalisation (``self.skipConnection``).

    Args:
        input_channel: channel count of the incoming tensor.
        out_channel: channel count produced by the block.
        reps: number of separable-convolution units in the main path.
        strides: when not 1, a ``MaxPool2d(3, strides, 1)`` is appended to the
            main path and the shortcut convolution uses the same stride.
        relu: if False the leading ReLU of the main path is dropped; if True
            it is replaced by a non-in-place ReLU.
        grow_first: if True the first unit changes the channel count to
            ``out_channel``; otherwise the last unit does.
    """

    def __init__(self, input_channel, out_channel, reps, strides=1, relu=True, grow_first=True):
        super().__init__()

        # A projection shortcut is needed whenever the main path changes the
        # channel count or the spatial resolution.
        if out_channel != input_channel or strides != 1:
            self.skipConnection = nn.Sequential(
                nn.Conv2d(input_channel, out_channel, 1, stride=strides, bias=False),
                nn.BatchNorm2d(out_channel),
            )
        else:
            self.skipConnection = None

        self.relu = nn.ReLU(inplace=True)

        # (in_channels, out_channels) for each separable-conv unit, in order.
        if grow_first:
            channel_plan = [(input_channel, out_channel)]
            channel_plan += [(out_channel, out_channel)] * (reps - 1)
        else:
            channel_plan = [(input_channel, input_channel)] * (reps - 1)
            channel_plan.append((input_channel, out_channel))

        rep = []
        for unit_in, unit_out in channel_plan:
            rep.append(self.relu)
            rep.append(SeparableConv(unit_in, unit_out, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(unit_out))

        if relu:
            # The first activation must not be in-place: it would overwrite
            # the block input, which forward() still needs for the shortcut.
            rep[0] = nn.ReLU(inplace=False)
        else:
            rep = rep[1:]

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))

        self.rep = nn.Sequential(*rep)

    def forward(self, input):
        """Return main-path output plus the (identity or projected) shortcut."""
        X = self.rep(input)
        # Explicit None check instead of relying on nn.Sequential truthiness
        # (which is derived from its length).
        if self.skipConnection is None:
            skip = input
        else:
            skip = self.skipConnection(input)
        X += skip
        return X
class Xception(nn.Module):
    """Xception-style convolutional classifier.

    Layout, as built in ``__init__``:
      * ``initBlock``: two 3x3 conv + BatchNorm + ReLU stages (first one
        with stride 2) ending at 64 channels;
      * ``block1``-``block3``: strided residual blocks, 64 -> 128 -> 256 -> 728;
      * ``block4``-``block11``: eight stride-1 residual blocks at 728 channels;
      * ``block12``: strided residual block, 728 -> 1024;
      * ``conv3``/``bn3`` and ``conv4``/``bn4``: separable convolutions to
        1536 and 2048 channels, each followed by ReLU in ``forward``;
      * global average pooling and a linear layer ``fc`` with ``n_classes``
        outputs.

    Args:
        input_channel: channel count of the input tensor.
        n_classes: number of outputs of the final linear layer.
    """

    def __init__(self, input_channel, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.relu = nn.ReLU(inplace=True)

        self.initBlock = nn.Sequential(
            nn.Conv2d(input_channel, 32, 3, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # block1 skips its leading ReLU because initBlock already ends in one.
        self.block1 = Block(64, 128, 2, 2, relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, relu=True, grow_first=True)

        # Eight identical stride-1 blocks; kept as individually named
        # attributes so parameter names in the state dict do not change.
        self.block4 = Block(728, 728, 3, 1, relu=True, grow_first=True)
        self.block5 = Block(728, 728, 3, 1, relu=True, grow_first=True)
        self.block6 = Block(728, 728, 3, 1, relu=True, grow_first=True)
        self.block7 = Block(728, 728, 3, 1, relu=True, grow_first=True)
        self.block8 = Block(728, 728, 3, 1, relu=True, grow_first=True)
        self.block9 = Block(728, 728, 3, 1, relu=True, grow_first=True)
        self.block10 = Block(728, 728, 3, 1, relu=True, grow_first=True)
        self.block11 = Block(728, 728, 3, 1, relu=True, grow_first=True)

        self.block12 = Block(728, 1024, 2, 2, relu=True, grow_first=False)

        self.conv3 = SeparableConv(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)
        self.conv4 = SeparableConv(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, self.n_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialise every Conv2d and BatchNorm2d in the network.

        Conv2d weights are drawn from N(0, sqrt(2 / n)) with
        n = kernel_h * kernel_w * out_channels; BatchNorm2d weights are set
        to 1 and biases to 0. Other layers keep their default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Compute the ``n_classes`` outputs of ``fc`` for the input ``x``."""
        x = self.initBlock(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)

        x = self.block12(x)

        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))

        # Pool each channel down to a single value, then drop the two
        # trailing 1x1 dimensions before the linear layer.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.fc(x)