-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain-cls-slim.py
executable file
·42 lines (32 loc) · 1.05 KB
/
train-cls-slim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#!/usr/bin/env python3
import os
import sys
sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'zoo/slim'))
import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets import nets_factory
import aardvark
flags = tf.app.flags
FLAGS = flags.FLAGS
# Optional. When set, Model.inference hands this value to
# aardvark.setup_finetune and requires --colorspace=RGB.
# NOTE(review): presumably a pretrained checkpoint path -- confirm in aardvark.
flags.DEFINE_string('finetune', None,
                    'pretrained model to fine-tune from (requires --colorspace RGB)')
# Architecture name forwarded to aardvark.create_stock_slim_network.
flags.DEFINE_string('net', 'resnet_v2_50', 'slim network architecture')
class Model(aardvark.ClassificationModel):
    """Image classifier built from a stock TF-slim network.

    The architecture is chosen by ``--net`` and constructed through
    ``aardvark.create_stock_slim_network``. When ``--finetune`` is set,
    ``inference`` additionally obtains ``init_session`` and
    ``variables_to_train`` from ``aardvark.setup_finetune`` and stores them
    on the instance.
    """

    def __init__(self):
        super().__init__()

    def inference(self, images, classes, is_training):
        """Build the network graph and return its logits.

        Args:
            images: input image tensor fed to the slim network
                (shape/dtype as expected by aardvark -- not visible here).
            classes: number of output classes, passed as ``num_classes``.
            is_training: training-mode switch forwarded to the network.

        Returns:
            The logits returned by ``aardvark.create_stock_slim_network``.

        Raises:
            ValueError: if ``--finetune`` is set but ``--colorspace`` is not
                ``'RGB'``.
        """
        logits = aardvark.create_stock_slim_network(
            FLAGS.net, images, is_training,
            num_classes=classes, global_pool=True)
        if FLAGS.finetune:
            # Explicit check rather than `assert`: asserts are removed under
            # `python -O`, which would silently allow a colorspace mismatch.
            # NOTE(review): presumably the pretrained weights expect RGB input.
            if FLAGS.colorspace != 'RGB':
                raise ValueError(
                    '--finetune requires --colorspace=RGB, got %r'
                    % (FLAGS.colorspace,))
            # The predicate matches variables whose name contains 'logits';
            # how it is used is defined by aardvark.setup_finetune
            # (presumably the classifier head that is not restored) -- confirm.
            self.init_session, self.variables_to_train = aardvark.setup_finetune(
                FLAGS.finetune, lambda x: 'logits' in x)
        return logits
def main(_):
    """Train a ``Model`` with ``aardvark.train``.

    Invoked by ``tf.app.run``; its positional argument (argv) is ignored.
    """
    aardvark.train(Model())
if __name__ == '__main__':
    try:
        tf.app.run()
    except KeyboardInterrupt:
        # Hide the traceback on Ctrl-C, but still exit non-zero
        # (128 + SIGINT, the usual shell convention) so scripts and job
        # schedulers do not mistake an aborted training run for success.
        print('interrupted', file=sys.stderr)
        sys.exit(130)