-
Notifications
You must be signed in to change notification settings - Fork 7
/
train_model.py
210 lines (183 loc) · 6.91 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
from config import config
from jobman import DD, expand
import utils
import deep_orderless_bernoulli_nade
import os, sys, socket
import os.path
class Logger(object):
    """Tee writer: mirrors every message to both the real stdout and a
    caller-supplied log file object."""
    def __init__(self, stdout_file):
        # Keep handles to the interactive stream and the log sink.
        self.terminal = sys.stdout
        self.log = stdout_file
    def write(self, message):
        # Forward the same text to both destinations.
        self.terminal.write(message)
        self.log.write(message)
class Unbuffered(object):
    """Wrap a stream so that every write is flushed immediately and is
    also mirrored into a log file on disk."""
    def __init__(self, stream, stdout_file):
        self.stream = stream
        # NOTE(review): the log file is opened here and never explicitly
        # closed; it lives for the lifetime of the wrapper.
        self.log_file = open(stdout_file, "w")
    def write(self, data):
        # Echo to the wrapped stream, flush so output appears at once,
        # then append the same data to the log file.
        out = self.stream
        out.write(data)
        out.flush()
        self.log_file.write(data)
def set_config(conf, args, add_new_key=False):
    """Recursively copy values from `args` into `conf`, in place.

    Parameters
    ----------
    conf : DD / dict-like
        Destination configuration, mutated in place.
    args : DD / dict-like
        Source of overriding values; the special 'jobman' bookkeeping
        key is always skipped.
    add_new_key : bool
        When True, keys missing from `conf` are created instead of
        raising KeyError.

    Raises
    ------
    KeyError
        If a key from `args` is absent from `conf` and `add_new_key`
        is False.
    """
    for key in args:
        if key == 'jobman':
            # jobman's own bookkeeping entry is never copied.
            continue
        v = args[key]
        if isinstance(v, DD):
            # Nested sub-config: recurse. Propagate add_new_key — the
            # original dropped it, silently forcing False below the
            # top level.
            set_config(conf[key], v, add_new_key)
        elif key in conf or add_new_key:
            # `key in conf` replaces the Python-2-only dict.has_key(),
            # which was removed in Python 3.
            conf[key] = convert_from_string(v)
        else:
            raise KeyError(key)
def convert_from_string(x):
    """Best-effort conversion of a string to the Python object it
    spells (int, float, list, bool, ...).

    Evaluates `x` with empty globals/locals; any failure (syntax
    error, unknown name, non-string input) leaves the value untouched
    and returns it as-is.
    """
    # NOTE(review): eval on arbitrary strings is unsafe for untrusted
    # input; acceptable here only because args come from the operator's
    # own command line / jobman state.
    try:
        value = eval(x, {}, {})
    except Exception:
        # Not evaluable -- hand back the raw input unchanged.
        return x
    return value
def evaluate_trained(config, state, channel):
    """Load a previously trained model's config and params, then evaluate.

    Expects config.load_trained.action == 1 and the saved
    'model_config.pkl' / 'model_params_e<epoch>.pkl' files under
    config.load_trained.from_path. The saved config overwrites the
    current one (the load_trained bookkeeping fields are preserved),
    everything is mirrored into `state`, and evaluation is delegated to
    deep_orderless_bernoulli_nade.evaluate_trained.
    """
    config_path = config.load_trained.from_path + 'model_config.pkl'
    epoch = config.load_trained.epoch
    params_path = config.load_trained.from_path + 'model_params_e%d.pkl'%(epoch)
    assert config_path is not None
    assert params_path is not None
    assert os.path.isfile(params_path)
    assert os.path.isfile(config_path)
    # Single-argument parenthesized print is valid under both
    # Python 2 and Python 3 (the original `print` statements were
    # Python-2-only syntax).
    print('load the config options from the best trained model')
    used_config = utils.load_pkl(config_path)
    action = config.load_trained.action
    assert action == 1
    # Preserve the bookkeeping fields before the saved config clobbers
    # them in set_config below.
    from_path = config.load_trained.from_path
    epoch = config.load_trained.epoch
    save_model_path = config.load_trained.from_path
    set_config(config, used_config)
    config.load_trained.action = action
    config.load_trained.from_path = from_path
    config.load_trained.epoch = epoch
    config.save_model_path = save_model_path
    model_type = config.model
    # set up automatically some fields in config
    if config.dataset.signature == 'MNIST_binary_russ':
        config[model_type].n_in = 784
        config[model_type].n_out = 784
    # Also copy back from config into state.
    for key in config:
        setattr(state, key, config[key])
    print('Model Type: %s'%model_type)
    print('Host: %s' % socket.gethostname())
    print('Command: %s' % ' '.join(sys.argv))
    print('initializing data engine')
    # (removed unused input_dtype/target_dtype locals)
    data_engine = None
    deep_orderless_bernoulli_nade.evaluate_trained(state, data_engine, params_path, channel)
def continue_train(config, state, channel):
    """Load a previously trained model's config and params, then resume
    training from the saved epoch.

    Expects config.load_trained.action == 2 and the saved
    'model_config.pkl' / 'model_params_e<epoch>.pkl' files under
    config.load_trained.from_path. The saved config overwrites the
    current one, `action` is reset to 0 so the resumed run behaves like
    normal training, `state` is refreshed, and control is handed to
    deep_orderless_bernoulli_nade.continue_train.
    """
    config_path = config.load_trained.from_path + 'model_config.pkl'
    epoch = config.load_trained.epoch
    params_path = config.load_trained.from_path + 'model_params_e%d.pkl'%(epoch)
    assert config_path is not None
    assert params_path is not None
    assert os.path.isfile(params_path)
    assert os.path.isfile(config_path)
    # Single-argument parenthesized print is valid under both
    # Python 2 and Python 3 (the original `print` statements were
    # Python-2-only syntax).
    print('load the config options from the best trained model')
    used_config = utils.load_pkl(config_path)
    action = config.load_trained.action
    assert action == 2
    # Preserve the bookkeeping fields before the saved config clobbers
    # them in set_config below. Unlike evaluate_trained, the resumed
    # run keeps the *current* save_model_path, not the loaded one.
    from_path = config.load_trained.from_path
    epoch = config.load_trained.epoch
    save_model_path = config.save_model_path
    set_config(config, used_config)
    # Reset action to 0 so the resumed run is treated as plain training.
    config.load_trained.action = 0
    config.load_trained.from_path = from_path
    config.load_trained.epoch = epoch
    config.save_model_path = save_model_path
    model_type = config.model
    # set up automatically some fields in config
    if config.dataset.signature == 'MNIST_binary_russ':
        config[model_type].n_in = 784
        config[model_type].n_out = 784
    # Also copy back from config into state.
    for key in config:
        setattr(state, key, config[key])
    print('Model Type: %s'%model_type)
    print('Host: %s' % socket.gethostname())
    print('Command: %s' % ' '.join(sys.argv))
    print('initializing data engine')
    # (removed unused input_dtype/target_dtype locals)
    data_engine = None
    deep_orderless_bernoulli_nade.continue_train(state, data_engine, params_path, channel)
def train_from_scratch(config, state, channel):
    """Start a fresh training run.

    Fills in model-size fields for known datasets, resolves and creates
    the model save directory, pickles the config there, mirrors the
    config into `state`, and delegates training to
    deep_orderless_bernoulli_nade.train_from_scratch.
    """
    model_type = config.model
    # set up automatically some fields in config
    if config.dataset.signature == 'MNIST_binary_russ':
        config[model_type].n_in = 784
        config[model_type].n_out = 784
    # manipulate the 'state
    # save the config file
    save_model_path = config.save_model_path
    if save_model_path == 'current':
        config.save_model_path = './'
        # to facilitate the use of cluster for multiple jobs
        save_path = './model_config.pkl'
    else:
        # run locally, save locally
        save_path = save_model_path + 'model_config.pkl'
        utils.create_dir_if_not_exist(config.save_model_path)
    # for stdout file logging
    #sys.stdout = Unbuffered(sys.stdout, state.save_model_path + 'stdout.log')
    # Single-argument parenthesized print is valid under both
    # Python 2 and Python 3 (the original `print` statements were
    # Python-2-only syntax).
    print('saving model config into %s'%save_path)
    utils.dump_pkl(config, save_path)
    # Also copy back from config into state.
    for key in config:
        setattr(state, key, config[key])
    print('Model Type: %s'%model_type)
    print('Host: %s' % socket.gethostname())
    print('Command: %s' % ' '.join(sys.argv))
    print('initializing data engine')
    # (removed unused input_dtype/target_dtype locals)
    data_engine = None
    deep_orderless_bernoulli_nade.train_from_scratch(state, data_engine, channel)
def main(state, channel=None):
    """Entry point: mirror `state` into the global config, then dispatch
    on config.load_trained.action.

    Actions: 0 = train from scratch, 1 = evaluate a trained model,
    2 = resume training. Returns 1 on completion; raises
    NotImplementedError for any other action code.
    """
    # copy state to config
    set_config(config, state)
    action = config.load_trained.action
    dispatch = {
        0: train_from_scratch,   # normal training
        1: evaluate_trained,     # load trained model and evaluate
        2: continue_train,       # load trained model, continue training
    }
    if action not in dispatch:
        raise NotImplementedError()
    dispatch[action](config, state, channel)
    return 1
def experiment(state, channel):
    # Jobman-facing entry point: run the full pipeline for this state,
    # then signal jobman that the job finished successfully.
    # called by jobman
    main(state, channel)
    return channel.COMPLETE
if __name__ == '__main__':
    # Parse command-line overrides of the form key=value; dotted keys
    # (e.g. "b.c=X") address nested config entries via jobman's expand().
    args = {}
    try:
        for arg in sys.argv[1:]:
            k, v = arg.split('=')
            args[k] = v
    except ValueError:
        # A malformed argument (no '=' or more than one) raises
        # ValueError on the tuple unpack. The original bare `except:`
        # also swallowed unrelated errors such as KeyboardInterrupt.
        # Parenthesized single-argument print works under Py2 and Py3.
        print('args must be like a=X b.c=X')
        exit(1)
    state = expand(args)
    sys.exit(main(state))