-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
64 lines (49 loc) · 2.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
class UpSample(nn.Sequential):
    """Decoder stage: resize, concatenate a skip tensor, then two 3x3 convolutions.

    ``x`` is bilinearly resized to the spatial size of ``concat_with``. The two
    tensors are concatenated along dim 1, and the result passes through
    convA -> LeakyReLU(0.2) -> convB -> LeakyReLU(0.2).

    Args:
        skip_input: input channels of convA, i.e. channels of ``x`` plus
            channels of ``concat_with`` after concatenation.
        output_features: output channels of both convolutions.

    Note:
        Earlier revisions declared ``leakyreluA`` but never applied it, so
        convA and convB composed linearly. The state_dict keys are unchanged,
        but checkpoints trained with that revision will produce different
        outputs under this forward pass.

    NOTE(review): this subclasses nn.Sequential but overrides ``forward`` with
    two inputs; nn.Module would suffice. The base class is kept for
    compatibility with code that may check it.
    """

    def __init__(self, skip_input, output_features):
        super(UpSample, self).__init__()
        self.convA = nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1)
        self.leakyreluA = nn.LeakyReLU(0.2)
        self.convB = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1)
        self.leakyreluB = nn.LeakyReLU(0.2)

    def forward(self, x, concat_with):
        """Fuse ``x`` with the skip tensor ``concat_with`` and refine the result.

        Assumes both inputs are 4-D (N, C, H, W): size(2)/size(3) of
        ``concat_with`` give the target spatial size and dim 1 is concatenated.
        """
        target_size = [concat_with.size(2), concat_with.size(3)]
        up_x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=True)
        merged = torch.cat([up_x, concat_with], dim=1)
        # Activation after convA: previously declared but skipped.
        hidden = self.leakyreluA(self.convA(merged))
        return self.leakyreluB(self.convB(hidden))
class Decoder(nn.Module):
    """Decode the encoder's feature list into a single-channel map.

    The deepest feature (``features[12]``) is ReLU'd and projected by a 1x1
    conv. Four UpSample stages then each fuse the running tensor with one
    skip feature, and a final 3x3 conv reduces the result to one channel.

    Args:
        num_features: channel count of ``features[12]`` (input of the 1x1 conv).
        decoder_width: multiplier on the decoder's bottleneck width,
            ``int(num_features * decoder_width)``.
    """

    def __init__(self, num_features=1664, decoder_width=1.0):
        super().__init__()
        width = int(num_features * decoder_width)
        # 1x1 bottleneck projection: num_features -> width channels.
        self.conv2 = nn.Conv2d(num_features, width, kernel_size=1, stride=1, padding=0)
        # Each stage halves the width. The added constant is the channel count
        # of the skip feature that stage concatenates (256, 128, 64, 64).
        self.up1 = UpSample(skip_input=width + 256, output_features=width // 2)
        self.up2 = UpSample(skip_input=width // 2 + 128, output_features=width // 4)
        self.up3 = UpSample(skip_input=width // 4 + 64, output_features=width // 8)
        self.up4 = UpSample(skip_input=width // 8 + 64, output_features=width // 16)
        self.conv3 = nn.Conv2d(width // 16, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, features):
        """Run the decoder over the indexable ``features`` sequence.

        NOTE(review): indices 3, 4, 6, 8, 12 presumably pick specific stages of
        the DenseNet-169 feature list built by Encoder. Confirm against the
        torchvision module order if the backbone changes.
        """
        # Skip tensors, deepest first, consumed by up1..up4 in order.
        skips = (features[8], features[6], features[4], features[3])
        out = self.conv2(F.relu(features[12]))
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4), skips):
            out = stage(out, skip)
        return self.conv3(out)
class Encoder(nn.Module):
    """Backbone that returns the input plus every intermediate feature map.

    Wraps torchvision's ``densenet169`` (loaded with ``pretrained=True``) and
    sets ``requires_grad = False`` on all of its parameters. Only
    ``original_model.features`` is used in ``forward``.

    NOTE(review): ``requires_grad = False`` does not stop normalization layers
    (e.g. BatchNorm) from updating running statistics in ``train()`` mode.
    Confirm whether the encoder should be kept in ``eval()``.
    NOTE(review): ``pretrained=True`` is deprecated in newer torchvision in
    favour of ``weights=...``. It is kept for compatibility with older releases.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.original_model = models.densenet169( pretrained=True )
        for param in self.original_model.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return ``[x, out_1, ..., out_k]``.

        ``out_i`` is the output of the i-th child of ``original_model.features``,
        and each child is fed the previous entry.
        """
        features = [x]
        # Iterate the Sequential directly (same order as its registered
        # children) instead of reaching into the private ``_modules`` dict.
        for layer in self.original_model.features:
            features.append(layer(features[-1]))
        return features
class Model(nn.Module):
    """Chains an Encoder and a Decoder: ``decoder(encoder(x))``."""

    def __init__(self):
        super().__init__()
        # Encoder is built before decoder so parameter registration order
        # (and random-init order) stays the same.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` into a feature list, then decode that list."""
        encoded = self.encoder(x)
        return self.decoder(encoded)