-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
159 lines (137 loc) · 6.17 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from keras.layers import Input, Dense, Conv2D, MaxPool2D, Flatten, concatenate, LSTM, Dropout, BatchNormalization, GlobalAveragePooling2D
from keras.models import Model
from keras.applications.xception import Xception
from keras.applications.mobilenet import MobileNet
from keras.applications.resnet50 import ResNet50
from keras.optimizers import Adam, SGD
from keras.layers import Embedding
import keras.regularizers as regularizers
import keras.losses as losses
import keras.backend as K
from keras.models import load_model as load_keras_model
class Config:
    '''
    Container for the hyperparameters and preprocessing artifacts read by build_model.

    The class-level values below are defaults. __init__ overrides embed_dim,
    n_classes and numeric_input_size on each instance; img_shape, batch_size
    and max_seq_len are only ever defined at class level.
    '''
    # Class-level defaults.
    img_shape = (224,224,3)
    max_seq_len = 30
    batch_size = 32
    embed_dim = 50
    n_classes = 100
    numeric_input_size = 2

    def __init__(
        self,
        word_index,
        embedding_matrix,
        tokenizer,
        lr=0.001,
        n_recurrent_layers=1,
        n_numeric_layers=3,
        trainable_convnet_layers=20,
        imagenet_weights=True,
        n_top_hidden_layers=1,
        n_convnet_fc_layers=2,
        n_classes=100,
        drop_prob=0.5,
        reg_weight=0.01,
        img_only=False,
        numeric_input_size=2,
        freeze_cnn=True,
        numeric_only=False,
        rnn_only=False,
        distance_weight=0.001,
    ):
        '''
        Stores every argument on the instance and derives vocab_size / embed_dim.

        :param word_index: sized collection of vocabulary entries; its length becomes vocab_size
        :param embedding_matrix: object with a 2-D .shape; shape[1] becomes embed_dim
        :param tokenizer: stored as-is
        :param lr: learning rate
        :param n_recurrent_layers: number of recurrent layers
        :param n_numeric_layers: stored as-is
        :param trainable_convnet_layers: stored as-is
        :param imagenet_weights: whether pretrained ImageNet weights are requested
        :param n_top_hidden_layers: stored as-is
        :param n_convnet_fc_layers: stored as-is
        :param n_classes: number of output classes
        :param drop_prob: dropout probability
        :param reg_weight: regularization weight
        :param img_only: use only the image branch
        :param numeric_input_size: width of the numeric input vector
        :param freeze_cnn: stored as-is
        :param numeric_only: use only the numeric branch
        :param rnn_only: use only the text branch
        :param distance_weight: weight of the auxiliary penalty term in the loss
        '''
        # Text / embedding artifacts and the sizes derived from them.
        self.tokenizer = tokenizer
        self.word_index = word_index
        self.embedding_matrix = embedding_matrix
        self.vocab_size = len(word_index)
        self.embed_dim = embedding_matrix.shape[1]

        # Architecture sizes.
        self.numeric_input_size = numeric_input_size
        self.n_classes = n_classes
        self.n_recurrent_layers = n_recurrent_layers
        self.n_numeric_layers = n_numeric_layers
        self.n_top_hidden_layers = n_top_hidden_layers
        self.n_convnet_fc_layers = n_convnet_fc_layers

        # Convolutional backbone options.
        self.imagenet_weights = imagenet_weights
        self.trainable_convnet_layers = trainable_convnet_layers
        self.freeze_cnn = freeze_cnn

        # Branch selection flags.
        self.img_only = img_only
        self.numeric_only = numeric_only
        self.rnn_only = rnn_only

        # Optimisation and regularisation.
        self.lr = lr
        self.drop_prob = drop_prob
        self.reg_weight = reg_weight
        self.distance_weight = distance_weight
def build_model(config):
    '''
    Builds and compiles the image / numeric / text classifier described by config.

    :param config: Config object supplying input shapes, the embedding matrix and hyperparameters
    :return: compiled keras Model taking [numeric_inputs, image input, text_inputs] and
             producing a softmax over config.n_classes
    '''
    # NOTE(review): img_inputs is created but never wired into the graph; the model's
    # image input is image_model.input below. Kept to avoid changing graph construction.
    img_inputs = Input(shape=config.img_shape, name='img_input')
    numeric_inputs = Input(shape=(config.numeric_input_size,))
    text_inputs = Input(shape=(config.max_seq_len,))

    # running cnn
    weights = 'imagenet' if config.imagenet_weights else None
    image_model = ResNet50(include_top=False, weights=weights)
    cnn_out = image_model.output
    x = GlobalAveragePooling2D()(cnn_out)
    x = Dense(512, activation='relu', kernel_regularizer=regularizers.l2(config.reg_weight))(x)
    x = Dropout(config.drop_prob)(x)
    cnn_out = x

    # running fc
    x = Dense(512, activation='relu', kernel_regularizer=regularizers.l2(config.reg_weight))(numeric_inputs)
    x = Dense(256, activation='relu', kernel_regularizer=regularizers.l2(config.reg_weight))(x)
    x = Dropout(config.drop_prob)(x)
    x = Dense(256, activation='relu', kernel_regularizer=regularizers.l2(config.reg_weight))(x)
    x = Dense(128, activation='relu', kernel_regularizer=regularizers.l2(config.reg_weight))(x)
    x = Dropout(config.drop_prob)(x)
    fc_out = x

    # running RNN
    embedding_layer = Embedding(len(config.word_index) + 1, config.embed_dim,
                                weights=[config.embedding_matrix],
                                trainable=False)
    embedded_seqs = embedding_layer(text_inputs)
    # Every LSTM except the last must emit the full sequence, otherwise the next
    # LSTM receives a 2-D tensor and stacking fails for n_recurrent_layers > 1.
    n_extra_layers = config.n_recurrent_layers - 1
    lstm = LSTM(64, return_sequences=n_extra_layers > 0)(embedded_seqs)
    for i in range(n_extra_layers):
        is_last = i == n_extra_layers - 1
        lstm = LSTM(32, kernel_regularizer=regularizers.l2(config.reg_weight),
                    return_sequences=not is_last)(lstm)
    rnn_out = lstm

    # Pick a single branch when one of the *_only flags is set, else concatenate all three.
    if config.img_only:
        x = cnn_out
    elif config.numeric_only:
        x = fc_out
    elif config.rnn_only:
        x = rnn_out
    else:
        x = concatenate([cnn_out, fc_out, rnn_out])
    predictions = Dense(config.n_classes, activation='softmax', name='main_output', kernel_regularizer=regularizers.l2(config.reg_weight))(x)
    model = Model(inputs=[numeric_inputs, image_model.input, text_inputs], outputs=predictions)

    # Freezing of the lower convnet layers is deliberately disabled here, so
    # config.freeze_cnn and config.trainable_convnet_layers currently have no effect.
    if False:#weights is not None and config.freeze_cnn:
        for i in range(len(image_model.layers) - config.trainable_convnet_layers):
            image_model.layers[i].trainable = False

    def custom_loss(y_true, y_pred):
        # Not passed to compile() below; cross-entropy plus a term that grows as the
        # predicted class index approaches n_classes / 2.
        epsilon = 0.001
        main_loss = losses.sparse_categorical_crossentropy(y_true, y_pred)
        pred_indices = K.argmax(y_pred, axis=-1)
        pred_indices = K.cast(pred_indices, dtype='float32')
        distance_penalty = K.constant(1.0, dtype='float32') / (K.abs(pred_indices - K.constant(config.n_classes / 2.0, dtype='float32')) + epsilon)
        return main_loss + config.distance_weight * distance_penalty

    def loss_with_var(y_true, y_pred):
        # Cross-entropy plus distance_weight times the variance of the predicted class indices.
        # NOTE(review): the penalty is computed from K.argmax, which presumably carries no
        # gradient, so it may only shift the reported loss value - confirm intent.
        main_loss = losses.sparse_categorical_crossentropy(y_true, y_pred)
        pred_indices = K.argmax(y_pred, axis=-1)
        pred_indices = K.cast(pred_indices, dtype='float32')
        var_penalty = K.var(pred_indices)
        return main_loss + config.distance_weight * var_penalty

    opt = Adam(lr=config.lr)
    model.compile(optimizer=opt, loss=loss_with_var, metrics=['sparse_categorical_accuracy', 'sparse_top_k_categorical_accuracy'])
    return model
import os
import pickle
def write_model(model, config, best_val_loss, model_folder):
    '''
    Saves the model as 'model.h5' and the pickled config as 'config' under model_folder.

    :param model: object exposing save(path), e.g. a keras model
    :param config: picklable config object
    :param best_val_loss: kept for interface compatibility; it is not part of the file names
    :param model_folder: path prefix the file names are appended to, including terminal slash
    '''
    # The original called .format(best_val_loss) on these names, but they contain no
    # placeholders, so the names were always the fixed strings used here.
    model.save(model_folder + 'model.h5')
    with open(model_folder + 'config', 'wb') as pickle_file:
        pickle.dump(config, pickle_file)
def load_config(folder):
    '''
    Loads config object from folder and prints warning in case of failure
    :param folder: path to load from, including terminal slash
    :return: config object, or None if the file could not be read or unpickled
    '''
    # SECURITY: pickle.load can execute arbitrary code; only load config files you trust.
    try:
        with open(folder + 'config', 'rb') as pickle_file:
            return pickle.load(pickle_file)
    except Exception:
        # Best-effort load: any ordinary failure yields None, but KeyboardInterrupt and
        # SystemExit are no longer swallowed as they were by the bare except.
        print('Warning: failed to load config')
        return None
def load_model(name):
    '''
    Loads saved model and config object
    :param name: Name of model saved from a previous training run
    :return: keras model object and config object (config is None if it could not be loaded)
    '''
    path = 'models/' + name + '/'
    # write_model stores the network as 'model.h5'; fall back to the extension-less
    # name this function used previously so older folders keep loading.
    model_path = path + 'model.h5'
    if not os.path.exists(model_path):
        model_path = path + 'model'
    # NOTE(review): build_model compiles with a locally defined loss function, so loading
    # such a model may additionally need custom_objects - confirm against the keras version in use.
    model = load_keras_model(model_path)
    config = load_config(path)
    return model, config