diff --git a/scripts/train_rnn.py b/scripts/train_rnn.py index c292b74..29338c4 100644 --- a/scripts/train_rnn.py +++ b/scripts/train_rnn.py @@ -99,9 +99,11 @@ def __init__(self, *, min_epochs=0, **kwargs): self.min_epochs = min_epochs def on_epoch_end(self, epoch_number: int, logs: Dict): - if epoch_number < self.min_epochs: - return super().on_epoch_end(epoch_number, logs) + # Un-stop (having updated best result anyway) + if epoch_number < self.min_epochs: + self.stopped_epoch = 0 + self.model.stop_training = False def run_cv(args, all_data, featdims, feat2id, label_counts, id2label):