From c9e6eb53bacfb8b5c7802ef7060a10097e2b5f97 Mon Sep 17 00:00:00 2001 From: David Huggins-Daines Date: Fri, 19 Jul 2024 14:59:49 -0400 Subject: [PATCH] fix: fix early stopping --- scripts/train_rnn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/train_rnn.py b/scripts/train_rnn.py index c292b74..29338c4 100644 --- a/scripts/train_rnn.py +++ b/scripts/train_rnn.py @@ -99,9 +99,11 @@ def __init__(self, *, min_epochs=0, **kwargs): self.min_epochs = min_epochs def on_epoch_end(self, epoch_number: int, logs: Dict): - if epoch_number < self.min_epochs: - return super().on_epoch_end(epoch_number, logs) + # Un-stop (having updated best result anyway) + if epoch_number < self.min_epochs: + self.stopped_epoch = 0 + self.model.stop_training = False def run_cv(args, all_data, featdims, feat2id, label_counts, id2label):