Skip to content

Commit

Permalink
Merge pull request #9 from IvanKuchin/development
Browse files Browse the repository at this point in the history
Enhancement: inroduce config parameter TRAIN_PASSES_PER_VALIDATION
  • Loading branch information
IvanKuchin authored Jul 17, 2024
2 parents 48ab20d + cef929e commit 7350528
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
EPOCHS = 1000
TRAIN_PASSES_PER_VALIDATION = 2

NUMBER_OF_CONV_IN_LAYER = 2

Expand Down
8 changes: 4 additions & 4 deletions tools/categorical_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def update_state(self, y_true, y_pred, sample_weight=None):
subj = tf.argmax(y_pred, axis = -1)

if self.reduce == "max":
self.m.assign(tf.reduce_max(subj))
self.m.assign(tf.reduce_max(tf.cast(subj, dtype=tf.float32)))
elif self.reduce == "min":
self.m.assign(tf.reduce_min(subj))
self.m.assign(tf.reduce_min(tf.cast(subj, dtype=tf.float32)))
elif self.reduce == "mean":
self.m.assign(tf.reduce_mean(subj))
self.m.assign(tf.reduce_mean(tf.cast(subj, dtype=tf.float32)))
elif self.reduce == "sum":
self.m.assign(tf.reduce_sum(subj))
self.m.assign(tf.reduce_sum(tf.cast(subj, dtype=tf.float32)))

def result(self):
return self.m
Expand Down
4 changes: 2 additions & 2 deletions train_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def main():
ds_train = craft_datasets(os.path.join(config.TFRECORD_FOLDER, "train"))
ds_valid = craft_datasets(os.path.join(config.TFRECORD_FOLDER, "valid"))

ds_train = ds_train.prefetch(1).repeat(2)
ds_train = ds_train.prefetch(1).repeat(config.TRAIN_PASSES_PER_VALIDATION)

model = craft_network(config.MODEL_CHECKPOINT)
# predict_on_random_data(model)
Expand Down Expand Up @@ -97,7 +97,7 @@ def main():
tensorboard_cb,
reduce_lr_on_plateau,
csv_logger,
# early_stopping,
early_stopping,
tf.keras.callbacks.TerminateOnNaN()],
verbose = 1,
# workers = 2
Expand Down

0 comments on commit 7350528

Please sign in to comment.