unet.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    """U-Net built from four down-steps and four up-steps, with skip
    connections between mirrored encoder and decoder stages."""

    def __init__(self, in_channels: int, out_channels: int, base_channels: int = 64):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, base_channels, kernel_size=1)
        # Encoder: each step doubles the channels and halves the spatial size.
        self.down1 = DownStep(base_channels, base_channels * 2)
        self.down2 = DownStep(base_channels * 2, base_channels * 4)
        self.down3 = DownStep(base_channels * 4, base_channels * 8)
        self.down4 = DownStep(base_channels * 8, base_channels * 16)
        # Decoder: each step halves the channels and doubles the spatial size.
        self.up1 = UpStep(base_channels * 16, base_channels * 8)
        self.up2 = UpStep(base_channels * 8, base_channels * 4)
        self.up3 = UpStep(base_channels * 4, base_channels * 2)
        self.up4 = UpStep(base_channels * 2, base_channels)
        self.conv_out = nn.Conv2d(base_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.conv_in(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.down4(x4)
        # Each up-step receives the encoder output of the mirrored stage.
        x = self.up1(x, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.conv_out(x)
        return x


class DownStep(nn.Module):
    """One encoder step: convolution followed by learned 2x downsampling."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = PoolBlock(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)
        return x


class UpStep(nn.Module):
    """One decoder step: learned 2x upsampling, concatenation with the skip
    tensor from the mirrored encoder stage, then convolution."""

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels // 2  # standard U-Net configuration
        self.up = UnpoolBlock(in_channels)
        # The skip tensor carries in_channels // 2 channels, so after
        # concatenation the convolution sees 1.5 * in_channels channels.
        self.conv = ConvBlock(in_channels + in_channels // 2, out_channels)

    def forward(self, current_x, skip_x):
        current_x = self.up(current_x)
        # current_x may be smaller than skip_x (odd spatial sizes do not
        # round-trip through pooling), so pad it before concatenation.
        difference_dim2 = skip_x.size()[2] - current_x.size()[2]  # difference in height
        difference_dim3 = skip_x.size()[3] - current_x.size()[3]  # difference in width
        current_x = F.pad(current_x, [
            difference_dim3 // 2, difference_dim3 - (difference_dim3 // 2),  # left and right
            difference_dim2 // 2, difference_dim2 - (difference_dim2 // 2)   # top and bottom
        ])
        x = torch.cat([skip_x, current_x], dim=1)
        x = self.conv(x)
        return x


class ConvBlock(nn.Module):
    """Size-preserving 3x3 convolution followed by activation and
    optional normalization."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            # Select one normalization (nn.Identity() for none):
            nn.BatchNorm2d(out_channels),
            # nn.InstanceNorm2d(out_channels),
            # nn.Identity(),
        )

    def forward(self, x):
        return self.layers(x)


class PoolBlock(nn.Module):
    """Learned 2x downsampling via a stride-2 convolution."""

    def __init__(self, n_channels: int):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(n_channels, n_channels, kernel_size=2, stride=2),
            nn.LeakyReLU(),
            # Select one normalization (nn.Identity() for none):
            nn.BatchNorm2d(n_channels),
            # nn.InstanceNorm2d(n_channels),
            # nn.Identity(),
        )

    def forward(self, x):
        return self.layers(x)


class UnpoolBlock(nn.Module):
    """Learned 2x upsampling via a stride-2 transposed convolution."""

    def __init__(self, n_channels: int):
        super().__init__()
        self.layers = nn.Sequential(
            nn.ConvTranspose2d(n_channels, n_channels, kernel_size=2, stride=2),
            nn.LeakyReLU(),
            # Select one normalization (nn.Identity() for none):
            nn.BatchNorm2d(n_channels),
            # nn.InstanceNorm2d(n_channels),
            # nn.Identity(),
        )

    def forward(self, x):
        return self.layers(x)
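

# A minimal smoke test (a sketch; the channel counts and input size below are
# illustrative assumptions, not from the original file). With a 64x64 input
# all four pool/unpool steps round-trip exactly; an input such as 57x57 would
# instead exercise the padding in UpStep.
if __name__ == "__main__":
    model = UNet(in_channels=3, out_channels=2)
    x = torch.randn(1, 3, 64, 64)  # (batch, channels, height, width)
    y = model(x)
    print(y.shape)  # expected: torch.Size([1, 2, 64, 64])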