diff --git a/config.py b/config.py index 19f80b0..d3c4931 100644 --- a/config.py +++ b/config.py @@ -1,4 +1,5 @@ EPOCHS = 1000 +TRAIN_PASSES_PER_VALIDATION = 2 NUMBER_OF_CONV_IN_LAYER = 2 diff --git a/tools/categorical_metrics.py b/tools/categorical_metrics.py index 9ac65db..3969da4 100644 --- a/tools/categorical_metrics.py +++ b/tools/categorical_metrics.py @@ -34,13 +34,13 @@ def update_state(self, y_true, y_pred, sample_weight=None): subj = tf.argmax(y_pred, axis = -1) if self.reduce == "max": - self.m.assign(tf.reduce_max(subj)) + self.m.assign(tf.reduce_max(tf.cast(subj, dtype=tf.float32))) elif self.reduce == "min": - self.m.assign(tf.reduce_min(subj)) + self.m.assign(tf.reduce_min(tf.cast(subj, dtype=tf.float32))) elif self.reduce == "mean": - self.m.assign(tf.reduce_mean(subj)) + self.m.assign(tf.reduce_mean(tf.cast(subj, dtype=tf.float32))) elif self.reduce == "sum": - self.m.assign(tf.reduce_sum(subj)) + self.m.assign(tf.reduce_sum(tf.cast(subj, dtype=tf.float32))) def result(self): return self.m diff --git a/train_segmentation.py b/train_segmentation.py index aef4716..c144bb5 100644 --- a/train_segmentation.py +++ b/train_segmentation.py @@ -49,7 +49,7 @@ def main(): ds_train = craft_datasets(os.path.join(config.TFRECORD_FOLDER, "train")) ds_valid = craft_datasets(os.path.join(config.TFRECORD_FOLDER, "valid")) - ds_train = ds_train.prefetch(1).repeat(2) + ds_train = ds_train.prefetch(1).repeat(config.TRAIN_PASSES_PER_VALIDATION) model = craft_network(config.MODEL_CHECKPOINT) # predict_on_random_data(model) @@ -97,7 +97,7 @@ def main(): tensorboard_cb, reduce_lr_on_plateau, csv_logger, - # early_stopping, + early_stopping, tf.keras.callbacks.TerminateOnNaN()], verbose = 1, # workers = 2