From d2549b709b9aabf440e4d181f4a7b3edacc8e411 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Tue, 28 May 2024 15:06:06 +0300 Subject: [PATCH] Move the update of _best_ckpt_metrics to before the checkpoint state dict is built --- .../training/sg_trainer/sg_trainer.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 566228bc64..b5325f911e 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -679,6 +679,15 @@ def _save_checkpoint( train_metrics_titles = get_metrics_titles(self.train_metrics) all_metrics["train"] = {metric_name: float(train_metrics_dict[metric_name]) for metric_name in train_metrics_titles} + best_checkpoint = (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or ( + curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better + ) + + if best_checkpoint: + # STORE THE CURRENT metric AS BEST + self.best_metric = curr_tracked_metric + self._best_ckpt_metrics = all_metrics + # BUILD THE state_dict state = { "net": unwrap_model(self.net).state_dict(), @@ -713,13 +722,7 @@ def _save_checkpoint( self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch) # OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST - if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or ( - curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better - ): - # STORE THE CURRENT metric AS BEST - self.best_metric = curr_tracked_metric - - self._best_ckpt_metrics = all_metrics + if best_checkpoint: self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch) # RUN PHASE CALLBACKS