Merge pull request #358 from wellcometrust/reuse-cnn
Do not initialize model if it exists in CNN
nsorros authored Nov 3, 2021
2 parents 2ebbdc5 + b8defc4 commit 2b712ed
Showing 2 changed files with 55 additions and 11 deletions.
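
In short: CNNClassifier.fit() no longer rebuilds the network when the instance already carries one. The private _build_model() becomes the public build_model() (its steps_per_epoch argument renamed to decay_steps), the built network is stored on self.model, and fit() skips initialization whenever that attribute exists. A minimal sketch of the workflow this enables, modelled on the new test below (the import paths and toy corpus are assumptions, not part of the diff):

import numpy as np
from wellcomeml.ml.keras_vectorizer import KerasVectorizer  # assumed import path
from wellcomeml.ml.cnn import CNNClassifier

X = ["one and two", "one only", "two nothing else", "two and three"]
Y = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])  # hypothetical labels

X_vec = KerasVectorizer().fit_transform(X)

batch_size = 2
model = CNNClassifier(batch_size=batch_size, multilabel=True)

# Build the network explicitly; the subsequent fit() reuses it
# (and warns "Using existing model") instead of rebuilding.
model.build_model(
    sequence_length=X_vec.shape[1],
    vocab_size=X_vec.max() + 1,
    nb_outputs=Y.shape[1],
    decay_steps=X_vec.shape[0] / batch_size,
)
model.fit(X_vec, Y)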
41 changes: 41 additions & 0 deletions tests/test_cnn.py
@@ -317,3 +317,44 @@ def test_multilabel_attention():
     ])
     model.fit(X, Y)
     assert model.score(X, Y) > 0.3
+
+
+def test_build_model():
+    X = [
+        "One and two",
+        "One only",
+        "Two nothing else",
+        "Two and three"
+    ]
+    Y = np.array([
+        [1, 1, 0, 0],
+        [1, 0, 0, 0],
+        [0, 1, 0, 0],
+        [0, 1, 1, 0]
+    ])
+
+    vectorizer = KerasVectorizer()
+    X_vec = vectorizer.fit_transform(X)
+
+    batch_size = 2
+    model = CNNClassifier(
+        batch_size=batch_size,
+        multilabel=True, learning_rate=1e-2)
+    model.fit(X_vec, Y)
+
+    Y_pred = model.predict(X_vec)
+    assert Y_pred.shape[1] == 4
+
+    Y = Y[:, :3]
+    sequence_length = X_vec.shape[1]
+    vocab_size = X_vec.max() + 1
+    nb_outputs = Y.shape[1]
+    decay_steps = X_vec.shape[0] / batch_size
+
+    model.build_model(
+        sequence_length, vocab_size,
+        nb_outputs, decay_steps)
+    model.fit(X_vec, Y)
+
+    Y_pred = model.predict(X_vec)
+    assert Y_pred.shape[1] == 3
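
Note the decay_steps computed in the test: X_vec.shape[0] / batch_size is exactly the steps-per-epoch figure fit() derives internally, so the ExponentialDecay schedule inside build_model() (see the next file) lowers the learning rate once per epoch. A small sketch of that schedule's arithmetic, using tf.keras's staircase semantics; the 0.9 decay rate here is an illustrative value, not the class default:

import math

def lr_at(step, lr0=1e-2, decay_rate=0.9, decay_steps=2):
    # ExponentialDecay with staircase=True drops the rate in discrete
    # jumps, once every `decay_steps` optimizer updates.
    return lr0 * decay_rate ** math.floor(step / decay_steps)

# 4 samples / batch_size 2 -> decay_steps = 2, one drop per epoch:
assert lr_at(0) == 1e-2          # epoch 1
assert lr_at(2) == 1e-2 * 0.9    # epoch 2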
25 changes: 14 additions & 11 deletions wellcomeml/ml/cnn.py
@@ -166,8 +166,8 @@ def _init_from_data(self, X, Y):
         )
         return steps_per_epoch

-    def _build_model(self, sequence_length, vocab_size, nb_outputs,
-                     steps_per_epoch, embedding_matrix=None):
+    def build_model(self, sequence_length, vocab_size, nb_outputs,
+                    decay_steps, embedding_matrix=None):
         def residual_conv_block(x1, l2):
             filters = x1.shape[2]
             x2 = tf.keras.layers.Conv1D(
@@ -256,10 +256,10 @@ def residual_attention(x1):
         if self.feature_approach == "multilabel-attention":
             # out shape is (batch_size, attention heads (or outputs), 1)
             out = tf.keras.layers.Flatten()(out)
-        model = tf.keras.Model(inp, out)
+        self.model = tf.keras.Model(inp, out)

         learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
-            self.learning_rate, steps_per_epoch, self.learning_rate_decay,
+            self.learning_rate, decay_steps, self.learning_rate_decay,
             staircase=True
         )

@@ -272,8 +272,8 @@ def residual_attention(x1):
             METRIC_DICT[m] if m in METRIC_DICT else m
             for m in self.metrics
         ]
-        model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=metrics)
-        return model
+        self.model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=metrics)
+        return self.model

     def fit(self, X, Y=None, embedding_matrix=None, steps_per_epoch=None):
         if isinstance(X, list):
@@ -302,11 +302,14 @@ def fit(self, X, Y=None, embedding_matrix=None, steps_per_epoch=None):
         train_data = data.take(steps_per_epoch)
         val_data = data.skip(steps_per_epoch)

-        strategy = self._get_distributed_strategy()
-        with strategy.scope():
-            self.model = self._build_model(
-                self.sequence_length, self.vocab_size, self.nb_outputs,
-                steps_per_epoch, embedding_matrix)
+        if hasattr(self, "model"):
+            logger.warning("Using existing model")
+        else:
+            strategy = self._get_distributed_strategy()
+            with strategy.scope():
+                self.build_model(
+                    self.sequence_length, self.vocab_size, self.nb_outputs,
+                    steps_per_epoch, embedding_matrix)

         callbacks = []
         if self.tensorboard_log_path:
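
One consequence of guarding on hasattr(self, "model"): every later fit() call warns "Using existing model" and keeps training the same network, which is what lets the test call build_model() with fewer outputs and fit again. A hedged sketch of forcing a fresh build, since the diff adds no reset helper; deleting the attribute is a hypothetical manual workaround relying on plain Python semantics:

model.fit(X_vec, Y)   # first call: builds the network, then trains it
model.fit(X_vec, Y)   # warns "Using existing model", trains it further

del model.model       # hypothetical reset: hasattr(self, "model") is False again
model.fit(X_vec, Y)   # rebuilds from scratch inside the distribution strategy scope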
