-
Notifications
You must be signed in to change notification settings - Fork 1
/
model_loading.py
46 lines (35 loc) · 1.77 KB
/
model_loading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import keras
from keras.models import model_from_yaml
from sell_signal_agent import SellSignalAgent
def _make_optimizer():
    """Return a fresh RMSprop optimizer configured for models rebuilt by load_model."""
    # NOTE(review): `lr` is the older Keras spelling; newer releases expect
    # `learning_rate`. Kept as-is to match the Keras version this file targets.
    return keras.optimizers.RMSprop(lr=0.00025, rho=0.95, epsilon=0.01)


def load_model(agent, loss_function='mean_squared_error'):
    """Rebuild an agent's Keras model from ``<ClassName>.yaml`` and ``<ClassName>.h5``.

    The architecture is read from the YAML file and the weights from the HDF5
    file. Both files are named after the agent's class and resolved relative to
    the current working directory. The rebuilt model is stored on
    ``agent.model.model`` and compiled with RMSprop.

    For ``SellSignalAgent`` instances, a second copy of the architecture is also
    built. It receives the same loaded weights, is compiled, and is stored on
    ``agent.model.target_model``.

    Args:
        agent: Object whose class name selects the files to read and whose
            ``model`` attribute receives the loaded network(s).
        loss_function: Loss passed to ``compile`` for every rebuilt model.
    """
    name = agent.__class__.__name__
    # Load the YAML; `with` closes the file even if the read raises.
    with open(name + '.yaml', 'r') as yaml_file:
        loaded_model_yaml = yaml_file.read()
    # NOTE(review): model_from_yaml deserializes arbitrary model config, and
    # newer Keras releases removed it over code-execution concerns. Only load
    # files you produced yourself.
    loaded_model = model_from_yaml(loaded_model_yaml)
    # Load the weights into the new model.
    loaded_model.load_weights(name + ".h5")
    agent.model.model = loaded_model
    print(name + " - Loaded model from disk")
    # Compile the loaded model before further use.
    agent.model.model.compile(_make_optimizer(), loss=loss_function)
    if isinstance(agent, SellSignalAgent):
        loaded_model2 = model_from_yaml(loaded_model_yaml)
        # The target network needs the trained weights too. Without this it
        # would keep its random initial weights.
        loaded_model2.set_weights(loaded_model.get_weights())
        agent.model.target_model = loaded_model2
        # Use a separate optimizer instance so the two models don't share
        # optimizer state.
        agent.model.target_model.compile(_make_optimizer(), loss=loss_function)
def save_model(agent):
    """Write ``agent.model.model`` to ``<ClassName>.yaml`` (architecture) and ``<ClassName>.h5`` (weights).

    Both files are written relative to the current working directory, under the
    same names that ``load_model`` reads.

    Args:
        agent: Object whose class name picks the file names and whose
            ``model.model`` provides ``to_yaml()`` and ``save_weights(path)``.
    """
    keras_model = agent.model.model
    stem = str(agent.__class__.__name__)
    # Serialize the architecture before opening the output file, so a failing
    # to_yaml() leaves no empty file behind.
    architecture = keras_model.to_yaml()
    with open(stem + ".yaml", "w") as out:
        out.write(architecture)
    # Weights go to HDF5 alongside the YAML.
    keras_model.save_weights(stem + ".h5")
    print("Saved " + stem + " model to disk")
def save_tf_model(agent):
    """Checkpoint the agent's TensorFlow session to ``./models/<ClassName>.ckpt``.

    The ``./models`` directory is created first if it does not exist, because
    ``tf.train.Saver.save`` refuses to write into a missing parent directory.

    Args:
        agent: Object whose ``model`` has a ``saver`` with a
            ``save(sess, path)`` method and a ``sess`` to checkpoint.

    Returns:
        The value returned by ``agent.model.saver.save``. For ``tf.train.Saver``
        this is the checkpoint path prefix. The original version discarded this
        value.
    """
    import os

    name = agent.__class__.__name__
    os.makedirs("./models", exist_ok=True)
    save_path = agent.model.saver.save(agent.model.sess, "./models/" + name + ".ckpt")
    print(str(name) + " - Model Saved")
    return save_path
def load_tf_model(agent):
    """Restore the agent's TensorFlow session from ``./models/<ClassName>.ckpt``.

    Args:
        agent: Object whose ``model`` has a ``saver`` with a
            ``restore(sess, path)`` method and a ``sess`` to restore into.
    """
    class_name = str(agent.__class__.__name__)
    checkpoint = "./models/" + agent.__class__.__name__ + ".ckpt"
    tf_model = agent.model
    # Pull the saved variable values back into the live session.
    tf_model.saver.restore(tf_model.sess, checkpoint)
    print(class_name + " - Loaded model from disk")