# generator.py
# importing libraries
import torch
import torch.nn as nn
class ConvolutionalBlock(nn.Module):
    def __init__(  # initializes the values for the conv layers
        self,
        in_channels: int,
        out_channels: int,
        is_downsampling: bool = True,
        add_activation: bool = True,
        **kwargs
    ):
        super().__init__()
        if is_downsampling:  # convolutional layer for downsampling
            # sequential block: convolution, instance normalization, activation
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True) if add_activation else nn.Identity(),
            )
        else:  # otherwise we are upsampling
            self.conv = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True) if add_activation else nn.Identity(),
            )

    def forward(self, x):
        return self.conv(x)
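
# Shape arithmetic for the two modes (an illustrative sketch; the exact kwargs are
# supplied by the caller, so the numbers below assume kernel_size=3, stride=2,
# padding=1, output_padding=1 as used in Generator further down):
#   downsampling, nn.Conv2d:          H_out = floor((H + 2*1 - 3) / 2) + 1 = ceil(H / 2)
#   upsampling,   nn.ConvTranspose2d: H_out = (H - 1)*2 - 2*1 + 3 + 1 = 2 * H
# so each downsampling block roughly halves the spatial size and each upsampling
# block doubles it back.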
class ResidualBlock(nn.Module):
    def __init__(self, channels: int):
        """
        In a residual block, the use of two ConvolutionalBlock instances, one with
        an activation function and one without, is a design choice that promotes
        the learning of residual information.

        The purpose of a residual block is to learn the residual mapping between
        the input and output of the block. The first ConvolutionalBlock in the
        sequence, which includes an activation function, helps in capturing and
        extracting important features from the input. The activation function
        introduces non-linearity, allowing the network to model complex
        relationships between the input and output.

        The second ConvolutionalBlock does not include an activation function.
        It mainly focuses on adjusting the dimensions (e.g., number of channels)
        of the features extracted by the first ConvolutionalBlock. The absence of
        an activation function in the second ConvolutionalBlock allows the block
        to learn the residual information. By directly adding the output of the
        second ConvolutionalBlock to the original input, the block learns to
        capture the residual features or changes needed to reach the desired
        output.

        (Information and explanation above generated by ChatGPT)
        """
        super().__init__()
        self.block = nn.Sequential(
            ConvolutionalBlock(channels, channels, add_activation=True, kernel_size=3, padding=1),
            ConvolutionalBlock(channels, channels, add_activation=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """
        This skip connection, achieved through the addition operation, helps
        in propagating gradients during training and alleviates the vanishing
        gradient problem. It also facilitates the flow of information from earlier
        layers to later layers, allowing the network to learn more effectively.

        (Information and explanation above generated by ChatGPT)
        """
        return x + self.block(x)
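
# Illustrative note: with kernel_size=3, padding=1 and the default stride of 1,
# both inner convolutions preserve the spatial size and channel count, so the
# elementwise addition in forward() is always well defined, e.g.
#   x = torch.randn(1, 256, 32, 32)
#   assert ResidualBlock(256)(x).shape == x.shape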
class Generator(nn.Module):
    def __init__(
        self, img_channels: int, num_features: int = 64, num_residuals: int = 9
    ):
        """
        The generator consists of an initial convolution, 2 downsampling/encoding
        blocks, a stack of residual blocks (6 blocks for 128 × 128 training images,
        9 blocks for 256 × 256 and higher-resolution images), 2 upsampling/decoding
        blocks, and a final convolution.

        In the notation of the CycleGAN paper, c7s1-k is a 7×7 convolution with
        stride 1 and k filters, dk is a 3×3 stride-2 convolution with k filters,
        Rk is a residual block with k filters, and uk is a 3×3 fractional-strided
        (upsampling) convolution with k filters.

        The network with 6 residual blocks can be written as:
        c7s1-64, d128, d256, R256, R256, R256, R256, R256, R256, u128, u64, and c7s1-3.

        The network with 9 residual blocks consists of:
        c7s1-64, d128, d256, R256, R256, R256, R256, R256, R256, R256, R256, R256,
        u128, u64, and c7s1-3.
        """
        super().__init__()
        self.initial_layer = nn.Sequential(
            nn.Conv2d(
                img_channels,
                num_features,
                kernel_size=7,
                stride=1,
                padding=3,
                padding_mode="reflect",
            ),
            nn.ReLU(inplace=True),
        )
        self.downsampling_layers = nn.ModuleList(
            [
                ConvolutionalBlock(
                    num_features,
                    num_features * 2,
                    is_downsampling=True,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
                ConvolutionalBlock(
                    num_features * 2,
                    num_features * 4,
                    is_downsampling=True,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            ]
        )
        self.residual_layers = nn.Sequential(
            *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
        )
        self.upsampling_layers = nn.ModuleList(
            [
                ConvolutionalBlock(
                    num_features * 4,
                    num_features * 2,
                    is_downsampling=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                ConvolutionalBlock(
                    num_features * 2,
                    num_features * 1,
                    is_downsampling=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
            ]
        )
        self.last_layer = nn.Conv2d(
            num_features * 1,
            img_channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )
    def forward(self, x):
        x = self.initial_layer(x)
        for layer in self.downsampling_layers:
            x = layer(x)
        x = self.residual_layers(x)
        for layer in self.upsampling_layers:
            x = layer(x)
        return torch.tanh(self.last_layer(x))
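

# A minimal smoke-test sketch, assuming 3-channel (RGB) inputs at 256 × 256
# (the batch size and resolution here are illustrative choices, not requirements).
# It checks that the generator maps an image back to the same shape and that the
# final tanh keeps outputs in [-1, 1].
if __name__ == "__main__":
    generator = Generator(img_channels=3, num_features=64, num_residuals=9)
    dummy = torch.randn(2, 3, 256, 256)
    with torch.no_grad():
        out = generator(dummy)
    assert out.shape == dummy.shape
    assert out.min() >= -1.0 and out.max() <= 1.0
    print("output shape:", tuple(out.shape))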