-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_distributed.py
101 lines (81 loc) · 3.6 KB
/
main_distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
'''
Reference: https://github.com/coldmanck/show-attend-and-tell
This is the distributed (parallel) version of the training script, intended to run in the Clusterone environment.
'''
#!/usr/bin/python
import tensorflow as tf
from config import Config
from model import CaptionGenerator
from dataset import prepare_train_data, prepare_eval_data, prepare_test_data
from clusterone_config import distributed_env
# Command-line flags (read in main() through tf.app.flags.FLAGS).
# Multi-line help texts are assembled from adjacent string literals, so their
# content does not depend on how the continuation lines are indented.

# Which stage of the pipeline to run.
tf.flags.DEFINE_string(
    'phase', 'train',
    'The phase can be train, eval or test')

# Whether to restore previously trained weights in the train phase.
tf.flags.DEFINE_boolean(
    'load', False,
    'Turn on to load a pretrained model from either '
    'the latest checkpoint or a specified file')

# Explicit model file handed to model.load().
tf.flags.DEFINE_string(
    'model_file', None,
    'If sepcified, load a pretrained model from this file')

# Whether to load pretrained CNN weights in the train phase.
tf.flags.DEFINE_boolean(
    'load_cnn', False,
    'Turn on to load a pretrained CNN model')

# File handed to model.load_cnn() when --load_cnn is set.
tf.flags.DEFINE_string(
    'cnn_model_file', './vgg16_no_fc.npy',
    'The file containing a pretrained CNN model')

# Copied into config.train_cnn by main().
tf.flags.DEFINE_boolean(
    'train_cnn', False,
    'Turn on to train both CNN and RNN. '
    'Otherwise, only RNN is trained')

# Copied into config.beam_size by main().
tf.flags.DEFINE_integer(
    'beam_size', 3,
    'The size of beam search for caption generation')
def main(argv):
    """Configure the run from command-line flags and execute one phase.

    Copies the ``phase``, ``train_cnn`` and ``beam_size`` flags into a
    ``Config``, sets up the Clusterone distributed environment, builds the
    ``CaptionGenerator`` graph on the device that environment reports, and
    then runs training, evaluation or testing inside a
    ``tf.train.MonitoredTrainingSession``.

    Args:
        argv: Command-line arguments passed in by ``tf.app.run()``; unused.
    """
    flags = tf.app.flags
    FLAGS = flags.FLAGS

    config = Config()
    config.phase = FLAGS.phase
    config.train_cnn = FLAGS.train_cnn
    config.beam_size = FLAGS.beam_size

    # Cluster One setting
    clusterone_dist_env = distributed_env(config.root_path_to_local_data,
                                          config.path_to_local_logs,
                                          config.cloud_path_to_data,
                                          config.local_repo,
                                          config.cloud_user_repo,
                                          flags)
    clusterone_dist_env.get_env()

    tf.reset_default_graph()
    # Device placement for the model and the session target for this node.
    device, target = clusterone_dist_env.device_and_target()
    # end of setting

    # Prepare the inputs for the requested phase. Any phase other than
    # 'train' or 'eval' is treated as the testing phase.
    phase = FLAGS.phase
    if phase == 'train':
        data = prepare_train_data(config)
    elif phase == 'eval':
        config.batch_size = 1
        coco, data, vocabulary = prepare_eval_data(config)
    else:
        data, vocabulary = prepare_test_data(config)

    # The graph has to be complete BEFORE the monitored session is opened:
    # MonitoredTrainingSession finalizes the default graph when it is
    # created, so constructing the model (or any new op) inside the `with`
    # block fails with "Graph is finalized and cannot be modified".
    with tf.device(device):  # define model
        model = CaptionGenerator(config)

    # Using tensorflow's MonitoredTrainingSession to take care of checkpoints.
    # FLAGS.task_index and FLAGS.log_dir are not defined in this file.
    # NOTE(review): presumably they are registered by distributed_env through
    # the `flags` module passed to it above - confirm.
    # NOTE(review): with checkpoint_dir set, the default hooks expect a global
    # step tensor in the graph - presumably CaptionGenerator creates one;
    # confirm.
    with tf.train.MonitoredTrainingSession(
            master=target,
            is_chief=(FLAGS.task_index == 0),
            checkpoint_dir=FLAGS.log_dir) as sess:
        # No explicit tf.global_variables_initializer() run here: the
        # monitored session initializes or restores variables on the chief
        # and makes the other workers wait for it. Running an initializer
        # again would also overwrite values restored from checkpoint_dir.
        # The explicit graph.finalize() calls are gone as well, because the
        # graph is already finalized at this point.
        #
        # NOTE(review): model.load / model.load_cnn now run against a
        # finalized graph. If they create new ops while loading (e.g. assign
        # ops), they must be changed to build those ops up front - confirm in
        # the model implementation.
        if phase == 'train':
            # training phase
            if FLAGS.load:
                model.load(sess, FLAGS.model_file)
            if FLAGS.load_cnn:
                model.load_cnn(sess, FLAGS.cnn_model_file)
            model.train(sess, data)
        elif phase == 'eval':
            # evaluation phase
            model.load(sess, FLAGS.model_file)
            model.eval(sess, coco, data, vocabulary)
        else:
            # testing phase
            model.load(sess, FLAGS.model_file)
            model.test(sess, data, vocabulary)
if __name__ == '__main__':
    # Script entry point: tf.app.run parses the command-line flags and then
    # calls main with the remaining arguments.
    tf.app.run(main=main)