Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
The final training monitoring val_loss with patience 40 instead of accuracy
  • Loading branch information
SoroushOskouei authored Jun 3, 2024
1 parent 1314b37 commit 2c37071
Showing 1 changed file with 44 additions and 12 deletions.
56 changes: 44 additions & 12 deletions notebooks/ManyShotTransferLearning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"from tensorflow.keras import optimizers\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
"from tensorflow.keras.callbacks import Callback\n",
"from tensorflow.keras.callbacks import Callback, EarlyStopping, ModelCheckpoint\n",
"from MLD import multi_lens_distortion\n",
"\n",
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' \n",
Expand Down Expand Up @@ -175,29 +175,61 @@
"source": [
"model.compile(optimizer=optimizers.Adamax(1e-4), loss=\"CategoricalCrossentropy\", metrics=['accuracy'])\n",
"\n",
"# considering you want to monitor accuracy:\n",
"# # considering you want to monitor accuracy:\n",
"# acc_thresh = 0.95\n",
"\n",
"acc_thresh = 0.95\n",
"# class MyCallback(Callback):\n",
"# def on_epoch_end(self, epoch, logs=None):\n",
"# logs = logs or {}\n",
"# val_accuracy = logs.get('val_accuracy')\n",
"# if val_accuracy is not None and val_accuracy > acc_thresh:\n",
"# print(f'\\nEpoch {epoch}: Early stopping as val_accuracy is {val_accuracy}')\n",
"# self.model.stop_training = True\n",
"\n",
"class MyCallback(Callback):\n",
" def on_epoch_end(self, epoch, logs=None):\n",
" logs = logs or {}\n",
" val_accuracy = logs.get('val_accuracy')\n",
" if val_accuracy is not None and val_accuracy > acc_thresh:\n",
" print(f'\\nEpoch {epoch}: Early stopping as val_accuracy is {val_accuracy}')\n",
" self.model.stop_training = True\n",
"\n",
"# history = model.fit(\n",
"# train_generator,\n",
"# steps_per_epoch=len(train_generator),\n",
"# validation_data=validation_generator,\n",
"# validation_steps=len(validation_generator),\n",
"# epochs=200,\n",
"# callbacks=[MyCallback()]\n",
"# )\n",
"\n",
"# model.save('./PWCModel/')\n",
"\n",
"\n",
"# Parameters for EarlyStopping and ModelCheckpoint\n",
"patience = 40\n",
"\n",
"# Setting up callbacks for early stopping on minimum validation loss and saving the best model\n",
"early_stopping_callback = EarlyStopping(\n",
" monitor='val_loss',\n",
" patience=patience,\n",
" verbose=1,\n",
" mode='min',\n",
" restore_best_weights=True # Restores model weights from the epoch with the best value of the monitored quantity.\n",
")\n",
"\n",
"model_checkpoint_callback = ModelCheckpoint(\n",
" './PWCModel/best_model.h5', # Path where the model will be saved\n",
" monitor='val_loss',\n",
" save_best_only=True, # Only the best model according to the validation loss is saved\n",
" mode='min',\n",
" verbose=1\n",
")\n",
"\n",
"history = model.fit(\n",
" train_generator,\n",
" steps_per_epoch=len(train_generator),\n",
" validation_data=validation_generator,\n",
" validation_steps=len(validation_generator),\n",
" epochs=200,\n",
" callbacks=[MyCallback()]\n",
" callbacks=[early_stopping_callback, model_checkpoint_callback]\n",
")\n",
"\n",
"model.save('./PWCModel/')"
"# Save the overall model after training (optional, as the best model is already saved)\n",
"model.save('./PWCModel/best_PWC_model.h5')\n"
]
},
{
Expand Down

0 comments on commit 2c37071

Please sign in to comment.