-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
115 lines (97 loc) · 3.54 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import sys
from options.train_option import TrainOptions
from call_methods import make_model, make_dataset
from utils import tb_visualizer
from utils.utils import set_seed
import time
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def _train_one_epoch(model, train_dataset, visualizer, opt, epoch: int) -> None:
    """
    Run one training pass over ``train_dataset.dataloader`` and log it.

    Parameters
    ----------
    model
        Object exposing ``set_input``, ``train``, ``performance`` and ``vis_data``.
    train_dataset
        Object exposing a ``dataloader`` attribute that supports ``len()``.
    visualizer
        Logger exposing ``log_performance``, ``log_image`` and ``log_time``.
    opt
        Parsed options; reads ``train_dis_freq``, ``save_image_frequency``
        and ``print_freq``.
    epoch : int
        Zero-based index of the current epoch.

    Returns
    -------
    None
    """
    epoch_start = time.time()
    dataloader = train_dataset.dataloader
    # Loop-invariant: computed once per epoch instead of once per batch.
    steps_per_epoch = len(dataloader)
    for i, data in enumerate(dataloader):
        model.set_input(data)
        # The flag is raised on every ``train_dis_freq``-th batch of the epoch;
        # how the model uses it is decided inside ``model.train``.
        train_generator = i % opt.train_dis_freq == 0
        # Global batch counter across epochs, used as the logging step.
        total_steps = epoch * steps_per_epoch + i
        do_visualization = total_steps % opt.save_image_frequency == 0
        model.train(
            train_generator=train_generator, do_visualization=do_visualization
        )
        if train_generator:
            visualizer.log_performance(
                model.performance,
                epoch=epoch,
                step=i,
                total_steps=total_steps,
                is_train=True,
                print_freq=opt.print_freq,
            )
        if do_visualization:
            visualizer.log_image(
                model.vis_data, total_steps=total_steps, is_train=True
            )
    visualizer.log_time(time.time(), epoch_start, epoch, is_train=True)


def _test_one_epoch(model, test_dataset, visualizer, opt, epoch: int) -> None:
    """
    Run one evaluation pass over ``test_dataset.dataloader`` and log it.

    Parameters
    ----------
    model
        Object exposing ``set_input``, ``test``, ``performance`` and ``vis_data``.
    test_dataset
        Object exposing a ``dataloader`` attribute that supports ``len()``.
    visualizer
        Logger exposing ``log_performance``, ``log_image`` and ``log_time``.
    opt
        Parsed options; reads ``print_freq`` and ``save_image_frequency``.
    epoch : int
        Zero-based index of the current epoch.

    Returns
    -------
    None
    """
    epoch_start = time.time()
    dataloader = test_dataset.dataloader
    # Loop-invariant: computed once per epoch instead of once per batch.
    steps_per_epoch = len(dataloader)
    for i, data in enumerate(dataloader):
        model.set_input(data)
        total_steps = epoch * steps_per_epoch + i
        # Images are requested when either logging interval is hit.
        do_visualization = (
            total_steps % opt.print_freq == 0
            or total_steps % opt.save_image_frequency == 0
        )
        model.test(do_visualization=do_visualization)
        # Unlike the training pass, performance is logged for every batch.
        visualizer.log_performance(
            model.performance,
            epoch=epoch,
            step=i,
            total_steps=total_steps,
            is_train=False,
            print_freq=opt.print_freq,
        )
        if do_visualization:
            visualizer.log_image(
                model.vis_data, total_steps=total_steps, is_train=False
            )
    visualizer.log_time(time.time(), epoch_start, epoch, is_train=False)


def run() -> None:
    """
    Run the training process

    Parameters
    ----------
    None

    Returns
    -------
    None

    Process
    -------
    1. Parse the training options
    2. Set the random seed
    3. Create the model
    4. Create the dataset (train, and optionally test)
    5. Create the visualizer
    6. For each epoch: train, optionally test, update the learning rate
       once ``opt.lr_decay_start`` is reached, and save the networks every
       ``opt.model_save_frequency`` epochs and on the last epoch
    7. Log the total training time and close the visualizer
    """
    opt = TrainOptions().parse()
    # Set seed
    set_seed(opt.seed)
    model = make_model(opt.model_name, opt)
    dataset = make_dataset(dataset_name=opt.dataset_name, opt=opt)
    if isinstance(dataset, tuple) and len(dataset) == 2:
        train_dataset, test_dataset = dataset
    else:
        # NOTE(review): assumes make_dataset returns an indexable container
        # whose first item is the training dataset -- confirm against
        # call_methods.make_dataset.
        train_dataset, test_dataset = dataset[0], None
    visualizer = tb_visualizer.Visualizer(opt)
    # try/finally guarantees the visualizer is closed even when a training or
    # evaluation step raises; previously close() was skipped on any error.
    try:
        start = time.time()
        # Pre-bind so the final log_time call still has a value when
        # opt.n_epochs is 0 and the loop body never runs.
        epoch = 0
        for epoch in range(opt.n_epochs):
            _train_one_epoch(model, train_dataset, visualizer, opt, epoch)
            if test_dataset is not None:
                _test_one_epoch(model, test_dataset, visualizer, opt, epoch)
            if epoch >= opt.lr_decay_start:
                model.update_learning_rate()
            if epoch % opt.model_save_frequency == 0 or epoch == opt.n_epochs - 1:
                model.save_networks(visualizer.log_dir, epoch)
        end = time.time()
        visualizer.log_time(end, start, epoch, training_end=True)
    finally:
        visualizer.close()
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run()