-
Notifications
You must be signed in to change notification settings - Fork 0
/
small_decoder_model.py
44 lines (36 loc) · 1.63 KB
/
small_decoder_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
    """Upsample -> 3x3 conv -> batch norm -> activation.

    Each block enlarges the spatial dimensions of its input by ``scale_factor``
    (default 2, i.e. H and W are doubled) with bilinear interpolation, then maps
    ``in_channels`` to ``out_channels`` with a 3x3, stride-1, padding-1
    convolution (which keeps H and W unchanged). Callers generally halve the
    number of channels, but any ``out_channels`` value works.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        activation_function: ``'leaky_relu'`` (negative slope 0.02) or ``'tanh'``.
        scale_factor: Spatial upsampling factor applied in ``forward``.

    Raises:
        ValueError: If ``activation_function`` is not one of the supported names.
    """

    def __init__(self, in_channels, out_channels, activation_function='leaky_relu', scale_factor=2):
        super().__init__()
        # Plain attribute (not a parameter/buffer), so state_dict keys are unaffected.
        self.scale_factor = scale_factor
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        if activation_function == 'leaky_relu':
            self.activation_function = nn.LeakyReLU(negative_slope=0.02)
        elif activation_function == 'tanh':
            self.activation_function = nn.Tanh()
        else:
            # Fail loudly at construction time instead of with an AttributeError
            # on the first forward() call.
            raise ValueError(
                f"Unsupported activation_function {activation_function!r}; "
                "expected 'leaky_relu' or 'tanh'."
            )

    def forward(self, x):
        """Upsample ``x`` spatially, then apply conv, batch norm and the activation.

        Bilinear interpolation and ``nn.Conv2d`` require a 4D ``(N, C, H, W)``
        tensor with ``C == in_channels``. The result has ``out_channels``
        channels and spatial size scaled by ``scale_factor``.
        """
        # align_corners=False is PyTorch's default for bilinear mode; it is
        # spelled out so the resampling behavior is explicit.
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        x = self.bn(self.conv(x))
        return self.activation_function(x)
class SmallDecoder(nn.Module):
    """Stack of four ConvBlocks mapping 256 feature channels to a 3-channel image.

    Channel progression is 256 -> 128 -> 64 -> 32 -> 3. Each ConvBlock upsamples
    H and W by a factor of 2, so the stack enlarges the spatial size 16x overall.
    A 256x4x4 input (per sample) therefore becomes 3x64x64, and a 3x32x32 output
    would require a 256x2x2 input. The first three blocks use LeakyReLU; the last
    uses tanh, so output values lie in [-1, 1].

    NOTE(review): an earlier comment claimed 256x4x4 -> 3x32x32, which four
    doubling blocks do not produce. Confirm which output size is intended.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept stable so existing state_dict checkpoints load.
        self.conv_block1 = ConvBlock(256, 128)
        self.conv_block2 = ConvBlock(128, 64)
        self.conv_block3 = ConvBlock(64, 32)
        # Final block maps to 3 channels and squashes values with tanh.
        self.conv_block4 = ConvBlock(32, 3, 'tanh')

    def forward(self, x):
        """Decode ``x`` (N, 256, H, W) into an image tensor of shape (N, 3, 16*H, 16*W)."""
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        return x