Skip to content

Commit

Permalink
Added solution to Issue #88
Browse files Browse the repository at this point in the history
  • Loading branch information
abhi-kumar committed Jul 28, 2020
1 parent 8959e23 commit c2e947c
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions monk/tf_keras_1/finetune/level_3_training_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,16 @@ def set_training_final(self):

if(self.system_dict["training"]["settings"]["save_training_logs"]):
history_df = pd.read_csv(self.system_dict["log_dir"] + "/model_history_log.csv");
val_acc_history = history_df['val_acc'].tolist();
train_acc_history = history_df['acc'].tolist();
val_loss_history = history_df['val_loss'].tolist();
train_loss_history = history_df['loss'].tolist();
if(int(keras.__version__.split(".")[1]) > 2):
val_acc_history = history_df['val_accuracy'].tolist();
train_acc_history = history_df['accuracy'].tolist();
val_loss_history = history_df['val_loss'].tolist();
train_loss_history = history_df['loss'].tolist();
else:
val_acc_history = history_df['val_acc'].tolist();
train_acc_history = history_df['acc'].tolist();
val_loss_history = history_df['val_loss'].tolist();
train_loss_history = history_df['loss'].tolist();

f = open(self.system_dict["log_dir"] + "/times.txt", 'r');
lines = f.readlines();
Expand Down Expand Up @@ -189,7 +195,7 @@ def set_training_final(self):

self.system_dict["training"]["outputs"]["training_time"] = "{:.0f}m {:.0f}s".format(time_elapsed_since // 60, time_elapsed_since % 60);

if(keras.__version__.split(".")[1] == "3"):
if(int(keras.__version__.split(".")[1]) > 2):
val_acc_history = history.history['val_accuracy'];
val_loss_history = history.history['val_loss'];
train_acc_history = history.history['accuracy'];
Expand Down Expand Up @@ -323,7 +329,7 @@ def set_training_final(self):

self.system_dict["training"]["outputs"]["training_time"] = "{:.0f}m {:.0f}s".format(time_elapsed_since // 60, time_elapsed_since % 60);

if(keras.__version__.split(".")[1] == "3"):
if(int(keras.__version__.split(".")[1]) > 2):
val_acc_history = history.history['val_accuracy'];
val_loss_history = history.history['val_loss'];
train_acc_history = history.history['accuracy'];
Expand Down

0 comments on commit c2e947c

Please sign in to comment.