common.py
import torch
import torch.nn as nn
import numpy as np
from downsampler import Downsampler


def add_module(self, module):
    self.add_module(str(len(self) + 1), module)


# Monkey-patch a convenience `.add` method onto every nn.Module: the child is
# registered under the next integer index. Note that len(self) only works for
# container modules such as nn.Sequential.
torch.nn.Module.add = add_module
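

# Hypothetical usage sketch (not part of the original file): after the patch,
# children can be appended to a Sequential without naming them explicitly.
def _demo_add():
    seq = nn.Sequential()
    seq.add(nn.ReLU())  # registered under the name "1"
    assert isinstance(seq[0], nn.ReLU)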


class Concat(nn.Module):
    def __init__(self, dim, *args):
        super(Concat, self).__init__()
        self.dim = dim
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, input):
        # Run every child module on the same input.
        inputs = []
        for module in self._modules.values():
            inputs.append(module(input))

        inputs_shapes2 = [x.shape[2] for x in inputs]
        inputs_shapes3 = [x.shape[3] for x in inputs]

        if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):
            # All outputs already share the same spatial size.
            inputs_ = inputs
        else:
            # Center-crop every output to the smallest height and width.
            target_shape2 = min(inputs_shapes2)
            target_shape3 = min(inputs_shapes3)

            inputs_ = []
            for inp in inputs:
                diff2 = (inp.size(2) - target_shape2) // 2
                diff3 = (inp.size(3) - target_shape3) // 2
                inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3: diff3 + target_shape3])

        return torch.cat(inputs_, dim=self.dim)

    def __len__(self):
        return len(self._modules)
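

# Hypothetical sketch (not in the original file): Concat runs each branch on
# the same input, center-crops the results to the smallest shared spatial
# size, then concatenates along `dim`.
def _demo_concat():
    branch_a = nn.Conv2d(3, 8, 3, padding=1)  # keeps 32x32
    branch_b = nn.Conv2d(3, 8, 3, padding=0)  # shrinks to 30x30
    out = Concat(1, branch_a, branch_b)(torch.zeros(1, 3, 32, 32))
    assert out.shape == (1, 16, 30, 30)  # both branches cropped to 30x30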


class GenNoise(nn.Module):
    """Replace the input with fresh Gaussian noise that has `dim2` channels
    but keeps the input's batch size, spatial dims, dtype and device."""

    def __init__(self, dim2):
        super(GenNoise, self).__init__()
        self.dim2 = dim2

    def forward(self, input):
        a = list(input.size())
        a[1] = self.dim2

        # (the old torch.autograd.Variable wrapper is a no-op in modern PyTorch)
        b = torch.zeros(a).type_as(input)
        b.normal_()
        return b
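

# Hypothetical sketch (not in the original file): GenNoise ignores the input
# values entirely and only borrows its shape.
def _demo_gen_noise():
    noise = GenNoise(4)(torch.zeros(2, 3, 16, 16))
    assert noise.shape == (2, 4, 16, 16)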


class Swish(nn.Module):
    """Swish activation, x * sigmoid(x): https://arxiv.org/abs/1710.05941

    The hype was so huge that I could not help but try it.
    """

    def __init__(self):
        super(Swish, self).__init__()
        self.s = nn.Sigmoid()

    def forward(self, x):
        return x * self.s(x)
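

# Hypothetical sketch (not in the original file): Swish is elementwise
# x * sigmoid(x).
def _demo_swish():
    x = torch.randn(5)
    assert torch.allclose(Swish()(x), x * torch.sigmoid(x))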


def act(act_fun='LeakyReLU'):
    '''
    Build an activation layer from either a string name ('LeakyReLU', 'Swish',
    'ELU', 'none') or a module constructor (e.g. nn.ReLU).
    '''
    if isinstance(act_fun, str):
        if act_fun == 'LeakyReLU':
            return nn.LeakyReLU(0.2, inplace=True)
        elif act_fun == 'Swish':
            return Swish()
        elif act_fun == 'ELU':
            return nn.ELU()
        elif act_fun == 'none':
            return nn.Sequential()  # identity
        else:
            assert False, 'Unknown activation: ' + act_fun
    else:
        return act_fun()
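

# Hypothetical sketch (not in the original file): `act` accepts a string or a
# constructor.
def _demo_act():
    assert isinstance(act('LeakyReLU'), nn.LeakyReLU)
    assert isinstance(act(nn.ReLU), nn.ReLU)  # the constructor is called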


def bn(num_features):
    return nn.BatchNorm2d(num_features)


def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'):
    # If downsampling should not be done by the convolution's stride itself,
    # build a separate downsampling layer and run the convolution at stride 1.
    downsampler = None
    if stride != 1 and downsample_mode != 'stride':
        if downsample_mode == 'avg':
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == 'max':
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ['lanczos2', 'lanczos3']:
            downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True)
        else:
            assert False, 'Unknown downsample_mode: ' + downsample_mode
        stride = 1

    # 'zero' padding is handled by Conv2d itself; 'reflection' needs an
    # explicit ReflectionPad2d layer in front of the convolution.
    padder = None
    to_pad = int((kernel_size - 1) / 2)
    if pad == 'reflection':
        padder = nn.ReflectionPad2d(to_pad)
        to_pad = 0

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)

    layers = filter(lambda x: x is not None, [padder, convolver, downsampler])
    return nn.Sequential(*layers)
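

# Hypothetical sketch (not in the original file): with downsample_mode='avg'
# the stride moves out of the convolution and into an AvgPool2d layer, so the
# output is still halved spatially.
def _demo_conv():
    block = conv(3, 8, kernel_size=3, stride=2, pad='reflection', downsample_mode='avg')
    out = block(torch.zeros(1, 3, 32, 32))
    assert out.shape == (1, 8, 16, 16)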