-
Notifications
You must be signed in to change notification settings - Fork 2
/
post_training.py
39 lines (31 loc) · 1.51 KB
/
post_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
def epoch_sweep(args, vangan_model, plotter, test_path='', start=100, end=200, step=2, segmentation=True):
"""
Perform a sweep of epochs for the given VANGAN model and save the resulting images using the given plotter.
Args:
- args: command-line arguments
- vangan_model: a VANGAN object
- plotter: a GANMonitor object
- test_path (str): path to the directory containing the test images
- start (int): the starting epoch number (inclusive)
- end (int): the ending epoch number (inclusive)
- step (int): the number of epochs to skip between each saved image
- segmentation (bool): if True, generate segmentation images; otherwise, generate fake imaging domain images
Returns:
- None
"""
for i in range(start, end + 1, step):
print(i)
vangan_model.load_checkpoint(epoch=i,
newpath=args.output_dir+'/checkpoints')
# Make epoch folders
filepath = args.output_dir+'/Epoch_Sampling/'
folder = os.path.join(filepath, 'e{idx}'.format(idx=i))
if not os.path.isdir(folder):
os.makedirs(folder)
testfiles = os.listdir(test_path)
filename = 'e{idx}_VG_'.format(idx=i)
for file in testfiles:
testfiles[file] = os.path.join(test_path, file)
plotter.run_mapping(vangan_model, testfiles, args.INPUT_IMG_SIZE, filetext=filename, segmentation=segmentation,
stride=(50, 50, 50), filepath=folder, padFactor=0.1)