from lib.helper.logger import logger
from lib.core.base_trainer.net_work import Train
from lib.dataset.dataietr import FaceBoxesDataIter, DataIter
from lib.core.model.facebox.net import FaceBoxes

import tensorflow as tf
import cv2
import numpy as np

from train_config import config as cfg
import setproctitle

logger.info('The trainer start')

setproctitle.setproctitle("faceboxes")
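

# Entry point: builds the FaceBoxes model, wraps it in a tf.distribute
# MirroredStrategy scope for multi-GPU training, and feeds it with the
# data iterators configured in train_config.py.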
def main():
    epochs = cfg.TRAIN.epoch
    batch_size = cfg.TRAIN.batch_size
    enable_function = False

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    devices = ['/device:GPU:{}'.format(i) for i in range(cfg.TRAIN.num_gpu)]
    strategy = tf.distribute.MirroredStrategy(devices)
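
    # MirroredStrategy replicates the model on every device listed above and
    # keeps the replicas in sync; variables created inside strategy.scope()
    # below are mirrored across the cfg.TRAIN.num_gpu GPUs.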
    with strategy.scope():
        model = FaceBoxes()

        ### run the model once on a dummy image to build its variables
        image = np.zeros(shape=(1, 512, 512, 3), dtype=np.float32)
        model.inference(image)

        ## load pretrained weights
        if cfg.MODEL.pretrained_model is not None:
            logger.info('load pretrained params from %s' % cfg.MODEL.pretrained_model)
            model.load_weights(cfg.MODEL.pretrained_model)

        ### build the trainer
        trainer = Train(epochs, enable_function, model, batch_size, strategy)

        ### build the data iterators
        train_ds = DataIter(cfg.DATA.root_path, cfg.DATA.train_txt_path, True)
        test_ds = DataIter(cfg.DATA.root_path, cfg.DATA.val_txt_path, False)

        ### they are tensorpack-style data iterators; each call produces a full batch
        train_dataset = tf.data.Dataset.from_generator(
            train_ds,
            output_types=(tf.float32, tf.float32, tf.float32),
            output_shapes=([None, None, None, None], [None, None, None], [None, None]))
        test_dataset = tf.data.Dataset.from_generator(
            test_ds,
            output_types=(tf.float32, tf.float32, tf.float32),
            output_shapes=([None, None, None, None], [None, None, None], [None, None]))
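        # Each element yielded by the generators is already a full batch of
        # (images, labels, matches), so no .batch() call is needed here; the
        # output shapes are left fully dynamic (None).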

        train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        test_dataset = test_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

        train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
        test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

        ## check the data
        if cfg.TRAIN.vis:
            for images, labels, matches in train_dist_dataset:
                print(images)
                for i in range(images.shape[0]):
                    example_image = np.array(images[i], dtype=np.uint8)
                    example_label = np.array(labels[i])
                    cv2.imshow('example', example_image)
                    cv2.waitKey(0)

        ## train
        trainer.custom_loop(train_dist_dataset,
                            test_dist_dataset,
                            strategy)


if __name__ == '__main__':
    main()
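
# A minimal way to launch training (a sketch, assuming the dataset paths and
# hyperparameters in train_config.py are already filled in):
#
#   python train.py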