-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain-rpn.py
executable file
·56 lines (47 loc) · 1.75 KB
/
train-rpn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#!/usr/bin/env python3
import os
import math
import sys
# The C++ extension module ("cpp") is built in place with
# `python3 setup.py build`, which writes it under build/lib.<platform>-<tag>.
# The directory name depends on the interpreter, so derive it instead of
# hard-coding linux-x86_64 / Python 3.5. Older distutils/setuptools use
# <major>.<minor> as the tag (e.g. lib.linux-x86_64-3.5); newer setuptools
# use the implementation cache tag (e.g. lib.linux-x86_64-cpython-310).
# Both candidates go on the front of sys.path; the <major>.<minor> one ends
# up first, matching the original single hard-coded entry on Python 3.5.
import sysconfig
_here = os.path.abspath(os.path.dirname(__file__))
# filter(None, ...) drops cache_tag when it is None (bytecode caching disabled).
_build_tags = [sys.implementation.cache_tag, '%d.%d' % tuple(sys.version_info[:2])]
for _tag in filter(None, _build_tags):
    sys.path.insert(0, os.path.join(_here, 'build', 'lib.%s-%s' % (sysconfig.get_platform(), _tag)))
# TF-slim model zoo checkout, providing the `nets` package imported below.
sys.path.append(os.path.join(_here, 'zoo/slim'))
del _here, _build_tags, _tag
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets import nets_factory, resnet_utils
import aardvark
import cv2
from rpn import RPN
import cpp
flags = tf.app.flags
FLAGS = flags.FLAGS

# Command-line flags for this script, registered in this order as
# (definer, name, default, help). In this file, --backbone and
# --backbone_stride configure Model; nothing here reads --finetune, so it is
# presumably consumed elsewhere through the shared FLAGS object (TODO confirm).
_FLAG_SPECS = (
    (flags.DEFINE_string, 'finetune', None, ''),
    (flags.DEFINE_string, 'backbone', 'resnet_v2_50', 'architecture'),
    (flags.DEFINE_integer, 'backbone_stride', 16, ''),
)
for _define, _name, _default, _help in _FLAG_SPECS:
    _define(_name, _default, _help)
# Keep the module namespace identical to a plain sequence of DEFINE_* calls.
del _FLAG_SPECS, _define, _name, _default, _help
class Model(RPN):
    """Region-proposal network on top of a stock TF-slim backbone.

    The backbone architecture comes from --backbone and its output stride from
    --backbone_stride. The RPN base class (from the `rpn` module, not shown
    here) is presumably what calls rpn_backbone / rpn_logits / rpn_params and
    supplies ``self.is_training`` -- TODO confirm against rpn.RPN.
    """

    def __init__(self):
        # The backbone output stride is also handed to the RPN base class.
        super().__init__(FLAGS.backbone_stride)

    def rpn_backbone(self, images):
        """Build the backbone feature extractor over ``images``.

        Stores the network output in ``self.backbone`` and records
        ``self.backbone_stride``. Both are read by rpn_logits / rpn_params,
        so this must run before them.
        """
        self.backbone = aardvark.create_stock_slim_network(FLAGS.backbone, images, self.is_training, global_pool=False, stride=FLAGS.backbone_stride, scope='bb1')
        self.backbone_stride = FLAGS.backbone_stride

    def rpn_logits(self, channels, stride):
        """Return the ``channels``-channel map used as RPN logits, at ``stride``.

        No activation is applied (see _upsample_head).

        Raises:
            ValueError: if ``stride`` is not a positive divisor of the
                backbone stride.
        """
        return self._upsample_head(channels, stride)

    def rpn_params(self, channels, stride):
        """Return the ``channels``-channel RPN parameter map, at ``stride``.

        Built the same way as rpn_logits but with its own weights.

        Raises:
            ValueError: if ``stride`` is not a positive divisor of the
                backbone stride.
        """
        return self._upsample_head(channels, stride)

    def _upsample_head(self, channels, stride):
        """Upsample ``self.backbone`` from the backbone stride to ``stride``.

        Applies a single transposed convolution with an upscale factor of
        ``backbone_stride // stride``, a kernel size of twice that factor, and
        no activation, under aardvark's default arg scope. Each call creates a
        new set of weights.
        """
        # A stride that does not evenly divide the backbone stride would
        # otherwise truncate silently. That yields either a map at the wrong
        # stride, or an upscale of 0 (invalid kernel/stride) when
        # stride > backbone_stride.
        if stride <= 0 or self.backbone_stride % stride != 0:
            raise ValueError('stride %r must be a positive divisor of backbone_stride %r'
                             % (stride, self.backbone_stride))
        upscale = self.backbone_stride // stride
        with slim.arg_scope(aardvark.default_argscope(self.is_training)):
            return slim.conv2d_transpose(self.backbone, channels, 2 * upscale, upscale, activation_fn=None)
def main(_):
    """Script entry point, invoked by tf.app.run after flag parsing.

    The single argument (whatever tf.app.run passes, presumably the leftover
    argv) is ignored. Builds a Model and hands it to aardvark.train. Returns
    None.
    """
    aardvark.train(Model())
if __name__ == '__main__':
    # tf.app.run parses the tf.app.flags command line and then calls main.
    # Passing main explicitly is equivalent to its default lookup of
    # __main__.main when this file is run as a script.
    try:
        tf.app.run(main=main)
    except KeyboardInterrupt:
        # Deliberate: stopping training with Ctrl-C exits quietly, without a
        # traceback.
        pass