From 43cd474cb0437e878824f2090419ba18434e036c Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Fri, 24 Nov 2023 16:05:00 +0100
Subject: [PATCH] Save train time in checkpoint

---
 test/trainer/test_default_trainer.py |  7 +++++++
 torch_em/trainer/default_trainer.py  | 18 ++++++++++++++----
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/test/trainer/test_default_trainer.py b/test/trainer/test_default_trainer.py
index 04cb6d0d..65251690 100644
--- a/test/trainer/test_default_trainer.py
+++ b/test/trainer/test_default_trainer.py
@@ -56,8 +56,11 @@ def _get_kwargs(self, with_roi=False, compile_model=False):
 
     def test_fit(self):
         from torch_em.trainer import DefaultTrainer
+
         trainer = DefaultTrainer(**self._get_kwargs())
         trainer.fit(10)
+        train_time = trainer.train_time
+        self.assertGreater(train_time, 0.0)
 
         save_folder = os.path.join(self.checkpoint_folder, self.name)
         self.assertTrue(os.path.exists(save_folder))
@@ -69,6 +72,7 @@ def test_fit(self):
 
         trainer.fit(2)
         self.assertEqual(trainer.iteration, 12)
+        self.assertGreater(trainer.train_time, train_time)
 
         trainer = DefaultTrainer(**self._get_kwargs())
         trainer.fit(8, load_from_checkpoint="latest")
@@ -76,6 +80,7 @@ def test_fit(self):
 
     def test_from_checkpoint(self):
         from torch_em.trainer import DefaultTrainer
+
         trainer = DefaultTrainer(**self._get_kwargs(with_roi=True))
         trainer.fit(10)
         exp_model = trainer.model
@@ -86,6 +91,7 @@ def test_from_checkpoint(self):
             name="latest"
         )
         self.assertEqual(trainer.iteration, trainer2.iteration)
+        self.assertEqual(trainer.train_time, trainer2.train_time)
         self.assertEqual(trainer2.train_loader.dataset.raw.shape, exp_data_shape)
         self.assertTrue(torch_em.util.model_is_equal(exp_model, trainer2.model))
 
@@ -100,6 +106,7 @@ def test_from_checkpoint(self):
     @unittest.skipIf(sys.version_info.minor > 10, "Not supported for python > 3.10")
     def test_compiled_model(self):
         from torch_em.trainer import DefaultTrainer
+
         trainer = DefaultTrainer(**self._get_kwargs(compile_model=True))
         trainer.fit(10)
         exp_model = trainer.model
diff --git a/torch_em/trainer/default_trainer.py b/torch_em/trainer/default_trainer.py
index 83f958ed..36ded336 100644
--- a/torch_em/trainer/default_trainer.py
+++ b/torch_em/trainer/default_trainer.py
@@ -69,6 +69,7 @@ def __init__(
 
         self.mixed_precision = mixed_precision
         self.early_stopping = early_stopping
+        self.train_time = 0.0
 
         self.scaler = amp.GradScaler() if mixed_precision else None
 
@@ -457,7 +458,7 @@ def _initialize(self, iterations, load_from_checkpoint, epochs=None):
             best_metric = np.inf
         return best_metric
 
-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+    def save_checkpoint(self, name, best_metric, train_time=0.0, **extra_save_dict):
         save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
         extra_init_dict = extra_save_dict.pop("init", {})
         save_dict = {
@@ -468,6 +469,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
             "model_state": self.model.state_dict(),
             "optimizer_state": self.optimizer.state_dict(),
             "init": self.init_data | extra_init_dict,
+            "train_time": train_time,
         }
         save_dict.update(**extra_save_dict)
         if self.scaler is not None:
@@ -492,6 +494,7 @@ def load_checkpoint(self, checkpoint="best"):
         self._epoch = save_dict["epoch"]
         self._best_epoch = save_dict["best_epoch"]
         self.best_metric = save_dict["best_metric"]
+        self.train_time = save_dict.get("train_time", 0.0)
 
         model_state = save_dict["model_state"]
         # to enable loading compiled models
@@ -549,6 +552,7 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
         msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
         train_epochs = self.max_epoch - self._epoch
+        t_start = time.time()
         for _ in range(train_epochs):
             # run training and validation for this epoch
             t_per_iter = self._train_epoch(progress)
             current_metric = self._validate()
@@ -561,19 +565,22 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step(current_metric)
 
+            # how long did we train in total?
+            total_train_time = (time.time() - t_start) + self.train_time
+
             # save this checkpoint as the new best checkpoint if
             # it has the best overall validation metric
             if current_metric < best_metric:
                 best_metric = current_metric
                 self._best_epoch = self._epoch
-                self.save_checkpoint("best", best_metric)
+                self.save_checkpoint("best", best_metric, train_time=total_train_time)
 
             # save this checkpoint as the latest checkpoint
-            self.save_checkpoint("latest", best_metric)
+            self.save_checkpoint("latest", best_metric, train_time=total_train_time)
 
             # if we save after every k-th epoch then check if we need to save now
             if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
-                self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric)
+                self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric, train_time=total_train_time)
 
             # if early stopping has been specified then check if the stopping condition is met
             if self.early_stopping is not None:
@@ -591,6 +598,9 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
         if self._generate_name:
            self.name = None
 
+        # Update the train time
+        self.train_time = total_train_time
+
         # TODO save the model to wandb if we have the wandb logger
         if isinstance(self.logger, WandbLogger):
             self.logger.get_wandb().finish()
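
Usage sketch (not part of the patch): with this change the stored train time can also be read back directly from a checkpoint file, since save_checkpoint writes it under the "train_time" key. The checkpoint path below is hypothetical and depends on the trainer's checkpoint folder and name.

import torch

# Hypothetical path; adjust to the configured checkpoint folder and trainer name.
checkpoint_path = "./checkpoints/my-model/latest.pt"

# Checkpoints are plain torch save files; "train_time" is the key added by this patch.
# Checkpoints written before this change do not contain it, hence the 0.0 default.
save_dict = torch.load(checkpoint_path, map_location="cpu")
train_time = save_dict.get("train_time", 0.0)
print(f"Total training time so far: {train_time:.1f} s")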