-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain18.py
113 lines (90 loc) · 3.94 KB
/
train18.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
This code is part of an adaptation/modification from the original project available at:
https://github.com/peterwang512/CNNDetection
The original code was created by Wang et al. and is used here under the terms of the license
specified in the original project's repository. Any use of this adapted/modified code
must respect the terms of such license.
Adaptations and modifications made by: Daniel Cabanas Gonzalez
Modification date: 08/04/2024
"""
import os
import sys
import time
import torch
import torch.nn
import argparse
from PIL import Image
from tensorboardX import SummaryWriter
from validate18 import validate
from data import create_dataloader
from earlystop import EarlyStopping
from networks.trainer18 import Trainer
from options.train_options import TrainOptions
"""Currently assumes jpg_prob, blur_prob 0 or 1"""
def get_val_opt():
    """Build the options object used for per-epoch validation.

    A fresh ``TrainOptions`` object is parsed (with ``print_options=False``)
    and then adjusted so that:

    * ``dataroot`` points at the ``<dataroot>/<val_split>/`` subfolder;
    * ``isTrain`` is False, ``no_resize`` and ``no_crop`` are False, and
      ``serial_batches`` is True;
    * ``jpg_method`` is restricted to ``['pil']``;
    * a two-element ``blur_sig`` collapses to the single midpoint of its two
      values, and a ``jpg_qual`` list whose length is not 1 collapses to the
      midpoint of its first and last values, truncated to ``int``.

    NOTE(review): what each flag actually does (e.g. whether ``serial_batches``
    disables shuffling) is defined in the options/data modules -- confirm there.

    Returns:
        The adjusted options object.
    """
    opt = TrainOptions().parse(print_options=False)
    opt.dataroot = '{}/{}/'.format(opt.dataroot, opt.val_split)

    opt.isTrain = False
    opt.no_resize = False
    opt.no_crop = False
    opt.serial_batches = True
    opt.jpg_method = ['pil']

    # A [low, high] blur range is replaced by its single midpoint value.
    sigma_range = opt.blur_sig
    if len(sigma_range) == 2:
        low, high = sigma_range[0], sigma_range[1]
        opt.blur_sig = [(low + high) / 2]

    # Any JPEG quality list other than a single value is replaced by the
    # (int-truncated) midpoint of its first and last entries.
    quality_range = opt.jpg_qual
    if len(quality_range) != 1:
        first, last = quality_range[0], quality_range[-1]
        opt.jpg_qual = [int((first + last) / 2)]

    return opt
def _train_one_epoch(model, data_loader, opt, epoch, train_writer):
    """Run one pass over ``data_loader``, optimizing ``model`` batch by batch.

    Every ``opt.loss_freq`` global steps the current loss is printed and
    logged to ``train_writer``; every ``opt.save_latest_freq`` global steps
    the networks are saved under the ``'latest'`` tag.

    Args:
        model: the Trainer being optimized; its ``total_steps`` counter is
            incremented once per batch.
        data_loader: iterable of training batches.
        opt: parsed training options.
        epoch: current epoch index (used only in log messages).
        train_writer: TensorBoard writer for training scalars.
    """
    for data in data_loader:
        # total_steps is never reset in this script, so it counts batches
        # across all epochs and is used as the TensorBoard x-axis.
        model.total_steps += 1
        model.set_input(data)
        model.optimize_parameters()

        if model.total_steps % opt.loss_freq == 0:
            print("Train loss: {} at step: {}".format(model.loss, model.total_steps))
            train_writer.add_scalar('loss', model.loss, model.total_steps)

        if model.total_steps % opt.save_latest_freq == 0:
            print('saving the latest model %s (epoch %d, model.total_steps %d)' %
                  (opt.name, epoch, model.total_steps))
            model.save_networks('latest')


def main():
    """Train the detector, validating and checking early stopping each epoch.

    Each epoch trains over the full training loader, optionally checkpoints,
    then validates with the options from ``get_val_opt``. When early stopping
    triggers, the learning rate is lowered via ``model.adjust_learning_rate()``
    and training continues with a fresh patience window; if that call returns
    a falsy value, training stops. TensorBoard writers are always closed.
    """
    opt = TrainOptions().parse()
    opt.dataroot = '{}/{}/'.format(opt.dataroot, opt.train_split)
    val_opt = get_val_opt()

    data_loader = create_dataloader(opt)
    # len(DataLoader) is the number of batches, not samples; the number of
    # samples is len(data_loader.dataset) (for a ConcatDataset this equals the
    # sum of its sub-datasets' lengths).
    print('#training batches = %d' % len(data_loader))
    print('#total images = %d' % len(data_loader.dataset))

    log_dir = os.path.join(opt.checkpoints_dir, opt.name)
    train_writer = SummaryWriter(os.path.join(log_dir, "train"))
    val_writer = SummaryWriter(os.path.join(log_dir, "val"))

    model = Trainer(opt)
    # Called once per epoch with validation accuracy, so patience is measured
    # in epochs. NOTE(review): how the negative delta is interpreted is defined
    # by EarlyStopping -- confirm there.
    early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.001, verbose=True)

    try:
        for epoch in range(opt.niter):
            _train_one_epoch(model, data_loader, opt, epoch, train_writer)

            if epoch % opt.save_epoch_freq == 0:
                print('saving the model at the end of epoch %d, iters %d' %
                      (epoch, model.total_steps))
                model.save_networks('latest')
                model.save_networks(epoch)

            # Validation
            model.eval()
            acc, ap = validate(model.model, val_opt)[:2]
            val_writer.add_scalar('accuracy', acc, model.total_steps)
            val_writer.add_scalar('ap', ap, model.total_steps)
            print("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))

            early_stopping(acc, model)
            if early_stopping.early_stop:
                # Patience exhausted: lower the LR and restart the patience
                # window, or stop if adjust_learning_rate() reports falsy.
                if model.adjust_learning_rate():
                    print("Learning rate dropped by 10, continue training...")
                    early_stopping = EarlyStopping(
                        patience=opt.earlystop_epoch, delta=-0.002, verbose=True)
                else:
                    print("Early stopping.")
                    break
            model.train()
    finally:
        # Flush and close the event files even on early stop or an exception,
        # so the last logged scalars are not lost.
        train_writer.close()
        val_writer.close()


if __name__ == '__main__':
    main()