-
Notifications
You must be signed in to change notification settings - Fork 13
/
train_O_net.py
executable file
·54 lines (50 loc) · 2.66 KB
/
train_O_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import mxnet as mx
from core.imdb import IMDB
from train import train_net
from core.symbol import O_Net
def train_O_net(image_set, root_path, dataset_path, prefix, ctx,
                pretrained, epoch, begin_epoch,
                end_epoch, frequent, lr, resume):
    """Build the O-Net training set and symbol, then hand off to ``train_net``.

    Ground-truth records are loaded from an ``IMDB("mtcnn", ...)`` instance
    and extended with ``IMDB.append_flipped_images``. They are then passed to
    ``train_net`` together with a freshly built ``O_Net()`` symbol.

    Args:
        image_set: Name of the image set handed to ``IMDB``.
        root_path: Output data folder handed to ``IMDB``.
        dataset_path: Dataset folder handed to ``IMDB``.
        prefix: Model prefix forwarded to ``train_net``.
        ctx: List of MXNet contexts forwarded to ``train_net``.
        pretrained: Pretrained-model prefix forwarded to ``train_net``.
        epoch: Epoch of the pretrained model to load, forwarded to ``train_net``.
        begin_epoch: First training epoch, forwarded to ``train_net``.
        end_epoch: Last training epoch, forwarded to ``train_net``.
        frequent: Logging frequency, forwarded to ``train_net``.
        lr: Base learning rate, forwarded to ``train_net``.
        resume: When true, ``train_net`` receives ``False`` for its
            initialisation flag (presumably "continue from ``pretrained``
            instead of initialising from scratch" -- confirm in train.py).
    """
    # O-Net is the "48-net" stage; presumably train_net uses this as the
    # network input size -- confirm against train.py.
    net_input_size = 48

    dataset = IMDB("mtcnn", image_set, root_path, dataset_path)
    roidb = dataset.append_flipped_images(dataset.gt_imdb())

    net_symbol = O_Net()
    fresh_start = not resume
    train_net(net_symbol, prefix, ctx, pretrained, epoch, begin_epoch,
              end_epoch, roidb, net_input_size, frequent, fresh_start, lr)
def parse_args(argv=None):
    """Parse the command-line options for O-Net training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with attributes ``image_set``, ``root_path``,
        ``dataset_path``, ``prefix``, ``gpu_ids`` (a comma-separated string of
        device ids), ``pretrained``, ``epoch``, ``begin_epoch``, ``end_epoch``,
        ``frequent``, ``lr`` and ``resume``.
    """
    parser = argparse.ArgumentParser(description='Train O_net(48-net)',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--image_set', dest='image_set', help='training set',
                        default='train_48', type=str)
    parser.add_argument('--root_path', dest='root_path', help='output data folder',
                        default='data', type=str)
    parser.add_argument('--dataset_path', dest='dataset_path', help='dataset folder',
                        default='data/mtcnn', type=str)
    parser.add_argument('--prefix', dest='prefix', help='new model prefix',
                        default='model/onet', type=str)
    parser.add_argument('--gpus', dest='gpu_ids', help='GPU device to train with',
                        default='0', type=str)
    parser.add_argument('--pretrained', dest='pretrained', help='pretrained prefix',
                        default='model/onet', type=str)
    parser.add_argument('--epoch', dest='epoch', help='load epoch',
                        default=0, type=int)
    parser.add_argument('--begin_epoch', dest='begin_epoch', help='begin epoch of training',
                        default=0, type=int)
    parser.add_argument('--end_epoch', dest='end_epoch', help='end epoch of training',
                        default=16, type=int)
    parser.add_argument('--frequent', dest='frequent', help='frequency of logging',
                        default=20, type=int)
    parser.add_argument('--lr', dest='lr', help='learning rate',
                        default=0.01, type=float)
    parser.add_argument('--resume', dest='resume', help='continue training', action='store_true')
    # parse_args(None) falls back to sys.argv[1:], preserving the CLI behaviour.
    return parser.parse_args(argv)
if __name__ == '__main__':
args = parse_args()
print 'Called with argument:'
print args
ctx = [mx.gpu(int(i)) for i in args.gpu_ids.split(',')]
train_O_net(args.image_set, args.root_path, args.dataset_path, args.prefix,
ctx, args.pretrained, args.epoch, args.begin_epoch,
args.end_epoch, args.frequent, args.lr, args.resume)