-
Notifications
You must be signed in to change notification settings - Fork 1
/
CDCGAN.py
93 lines (87 loc) · 3.22 KB
/
CDCGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import argparse
from args import get_parser
# Build the project's shared CLI parser and parse arguments at import time.
# `opts` is a module-level singleton read by Generator/Discriminator below
# (nz, num_label, ngf, ndf). NOTE(review): parsing at import time means this
# module cannot be imported without valid CLI args — confirm that is intended.
parse = get_parser()
opts = parse.parse_args()
class Generator(nn.Module):
    """Conditional DCGAN generator.

    Two parallel transposed-conv stems embed the noise vector and the
    label tensor (each expected as N x C x 1 x 1) into (ngf*4) x 4 x 4
    feature maps. Their channel-wise concatenation — (ngf*8) x 4 x 4 —
    is upsampled by `main` through four stride-2 transposed convolutions
    to a 3 x 64 x 64 image squashed to [-1, 1] by Tanh.

    Hyperparameters (nz, num_label, ngf) come from the module-level
    `opts` parsed from the command line.
    """

    def __init__(self):
        super().__init__()
        noise_dim = int(opts.nz)
        label_dim = int(opts.num_label)
        width = int(opts.ngf)
        # Noise stem: (noise_dim x 1 x 1) -> (width*4) x 4 x 4.
        self.noise = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, width * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.ReLU(True),
        )
        # Label stem: (label_dim x 1 x 1) -> (width*4) x 4 x 4.
        # Presumably a one-hot class encoding — confirm against the caller.
        self.label = nn.Sequential(
            nn.ConvTranspose2d(label_dim, width * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.ReLU(True),
        )
        # Trunk: (width*8) x 4 x 4 -> 3 x 64 x 64, doubling the spatial
        # size at each transposed conv.
        self.main = nn.Sequential(
            nn.ConvTranspose2d(width * 8, width * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(width * 4, width * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(width * 2, width, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(True),
            nn.ConvTranspose2d(width, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input, label):
        """Generate images from noise `input` conditioned on `label`."""
        noise_feat = self.noise(input)
        label_feat = self.label(label)
        joint = torch.cat([noise_feat, label_feat], 1)
        return self.main(joint)
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    Takes an image and a label map with matching spatial size, stacks
    them channel-wise into a (3 + num_label) x 64 x 64 tensor, and runs
    a strided-conv pyramid that halves the spatial size at each step,
    ending in a single sigmoid unit per sample.

    Hyperparameters (ndf, num_label) come from the module-level `opts`
    parsed from the command line.
    """

    def __init__(self):
        super().__init__()
        width = int(opts.ndf)
        label_dim = int(opts.num_label)
        # 64x64 -> 32 -> 16 -> 8 -> 4 -> 1; BatchNorm on every layer
        # except the first and last, per the DCGAN recipe used here.
        self.main = nn.Sequential(
            nn.Conv2d(3 + label_dim, width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(width * 2, width * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(width * 4, width * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input, label):
        """Score (image, label) pairs; returns a 1-D tensor of probabilities."""
        joint = torch.cat([input, label], 1)
        scores = self.main(joint)
        # Flatten (N, 1, 1, 1) to (N,) — same result as view(-1, 1).squeeze(1).
        return scores.view(-1)