-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
53 lines (44 loc) · 1.39 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import argparse
from data_loader import get_loader
from args import get_parser
from utils import *
from DCGAN import Generator, Discriminator
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from tensorboardX import summary
from tensorboardX import FileWriter
# Expose only physical GPUs 0 and 2 to this process. The variable only takes
# effect if CUDA has not been initialised yet, hence it is set at import time.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'

# Run on the first visible GPU when CUDA is usable, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def main(opts):
    """Point the checkpoint prefix at the DCGAN sub-folder, then generate images.

    Args:
        opts: parsed command-line options. ``opts.checkpoints`` is a path
            prefix; it is modified in place by appending ``'DCGAN/'`` before
            ``test`` is called with the same object.
    """
    opts.checkpoints += 'DCGAN/'
    test(opts)
def test(opts, num_images=30000, out_dir='Generated/0/', epoch=20):
    """Sample images from a trained DCGAN generator and save them as PNG files.

    Loads ``<opts.checkpoints>Model/netG_epoch_<epoch>.pth`` into a
    ``DataParallel``-wrapped ``Generator`` and feeds it batches of standard
    normal noise until exactly ``num_images`` images have been written to
    ``out_dir`` as ``0000.png``, ``0001.png``, ...

    Args:
        opts: parsed options. Reads ``checkpoints`` (string path prefix that
            is concatenated directly with ``'Model/...'``), ``batch_size`` and
            ``nz`` (length of the noise vector fed to the generator).
        num_images: exact number of images to write (indices 0..num_images-1).
        out_dir: directory that receives the images; created if missing.
        epoch: epoch number of the generator checkpoint to load.

    Raises:
        ValueError: if ``opts.batch_size`` is not positive (generation
            would otherwise never make progress).
    """
    if opts.batch_size < 1:
        raise ValueError('batch_size must be positive, got %r' % (opts.batch_size,))

    netG = nn.DataParallel(Generator()).to(device)
    checkpoint = '%sModel/netG_epoch_%d.pth' % (opts.checkpoints, epoch)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    netG.load_state_dict(torch.load(checkpoint, map_location=device))
    netG.eval()

    # save_image does not create directories itself.
    os.makedirs(out_dir, exist_ok=True)

    count = 0
    # Pure inference: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        while count < num_images:
            noise = torch.randn(opts.batch_size, opts.nz, 1, 1, device=device)
            fake_images = netG(noise)
            # Only part of the final batch is kept so exactly num_images are saved.
            for image in fake_images[:num_images - count]:
                # NOTE: %04d is a minimum width; indices >= 10000 get 5 digits.
                vutils.save_image(image, os.path.join(out_dir, '%04d.png' % count),
                                  normalize=True)
                count += 1
            print(count)
if __name__ == '__main__':
    # Build the project's shared CLI parser and hand the parsed options to main().
    cli_parser = get_parser()
    main(cli_parser.parse_args())