Save train time in checkpoint
constantinpape committed Nov 24, 2023
1 parent 3734d19 commit 43cd474
Showing 2 changed files with 21 additions and 4 deletions.
7 changes: 7 additions & 0 deletions test/trainer/test_default_trainer.py
@@ -56,8 +56,11 @@ def _get_kwargs(self, with_roi=False, compile_model=False):
 
     def test_fit(self):
         from torch_em.trainer import DefaultTrainer
+
         trainer = DefaultTrainer(**self._get_kwargs())
         trainer.fit(10)
+        train_time = trainer.train_time
+        self.assertGreater(train_time, 0.0)
 
         save_folder = os.path.join(self.checkpoint_folder, self.name)
         self.assertTrue(os.path.exists(save_folder))
@@ -69,13 +72,15 @@ def test_fit(self):
 
         trainer.fit(2)
         self.assertEqual(trainer.iteration, 12)
+        self.assertGreater(trainer.train_time, train_time)
 
         trainer = DefaultTrainer(**self._get_kwargs())
         trainer.fit(8, load_from_checkpoint="latest")
         self.assertEqual(trainer.iteration, 20)
 
     def test_from_checkpoint(self):
         from torch_em.trainer import DefaultTrainer
+
         trainer = DefaultTrainer(**self._get_kwargs(with_roi=True))
         trainer.fit(10)
         exp_model = trainer.model
@@ -86,6 +91,7 @@ def test_from_checkpoint(self):
             name="latest"
         )
         self.assertEqual(trainer.iteration, trainer2.iteration)
+        self.assertEqual(trainer.train_time, trainer2.train_time)
         self.assertEqual(trainer2.train_loader.dataset.raw.shape, exp_data_shape)
         self.assertTrue(torch_em.util.model_is_equal(exp_model, trainer2.model))
 
@@ -100,6 +106,7 @@ def test_from_checkpoint(self):
     @unittest.skipIf(sys.version_info.minor > 10, "Not supported for python > 3.10")
     def test_compiled_model(self):
         from torch_em.trainer import DefaultTrainer
+
         trainer = DefaultTrainer(**self._get_kwargs(compile_model=True))
         trainer.fit(10)
         exp_model = trainer.model
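As the tests above illustrate, a resumed trainer restores its cumulative train time from the checkpoint and keeps adding to it. A minimal usage sketch; DefaultTrainer.from_checkpoint and the "latest" checkpoint name are taken from the tests above, while the checkpoint folder "./checkpoints/my-model" is hypothetical:

from torch_em.trainer import DefaultTrainer

# Restore the trainer, including its accumulated train time, from the latest checkpoint.
trainer = DefaultTrainer.from_checkpoint("./checkpoints/my-model", name="latest")
print(f"trained for {trainer.train_time:.1f} s so far")

# Further training accumulates on top of the restored time.
trainer.fit(1000)
print(f"trained for {trainer.train_time:.1f} s in total")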
18 changes: 14 additions & 4 deletions torch_em/trainer/default_trainer.py
@@ -69,6 +69,7 @@ def __init__(
 
         self.mixed_precision = mixed_precision
         self.early_stopping = early_stopping
+        self.train_time = 0.0
 
         self.scaler = amp.GradScaler() if mixed_precision else None
 
@@ -457,7 +458,7 @@ def _initialize(self, iterations, load_from_checkpoint, epochs=None):
             best_metric = np.inf
         return best_metric
 
-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+    def save_checkpoint(self, name, best_metric, train_time=0.0, **extra_save_dict):
         save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
         extra_init_dict = extra_save_dict.pop("init", {})
         save_dict = {
@@ -468,6 +469,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
             "model_state": self.model.state_dict(),
             "optimizer_state": self.optimizer.state_dict(),
             "init": self.init_data | extra_init_dict,
+            "train_time": train_time,
         }
         save_dict.update(**extra_save_dict)
         if self.scaler is not None:
@@ -492,6 +494,7 @@ def load_checkpoint(self, checkpoint="best"):
         self._epoch = save_dict["epoch"]
         self._best_epoch = save_dict["best_epoch"]
         self.best_metric = save_dict["best_metric"]
+        self.train_time = save_dict.get("train_time", 0.0)
 
         model_state = save_dict["model_state"]
         # to enable loading compiled models
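Since save_checkpoint writes the checkpoint as a plain dictionary to f"{name}.pt", the stored train time can also be read back directly. A minimal sketch, assuming a hypothetical checkpoint path; the .get fallback mirrors load_checkpoint above, so checkpoints written before this commit simply report 0.0:

import torch

# Load the raw checkpoint dictionary (the path is hypothetical).
save_dict = torch.load("./checkpoints/my-model/latest.pt", map_location="cpu")

# Older checkpoints lack the "train_time" key, hence the fallback.
train_time = save_dict.get("train_time", 0.0)
print(f"total train time stored in the checkpoint: {train_time:.1f} s")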
@@ -549,6 +552,7 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
         msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
 
         train_epochs = self.max_epoch - self._epoch
+        t_start = time.time()
         for _ in range(train_epochs):
 
             # run training and validation for this epoch
@@ -561,19 +565,22 @@
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step(current_metric)
 
+            # how long did we train in total?
+            total_train_time = (time.time() - t_start) + self.train_time
+
             # save this checkpoint as the new best checkpoint if
             # it has the best overall validation metric
             if current_metric < best_metric:
                 best_metric = current_metric
                 self._best_epoch = self._epoch
-                self.save_checkpoint("best", best_metric)
+                self.save_checkpoint("best", best_metric, train_time=total_train_time)
 
             # save this checkpoint as the latest checkpoint
-            self.save_checkpoint("latest", best_metric)
+            self.save_checkpoint("latest", best_metric, train_time=total_train_time)
 
             # if we save after every k-th epoch then check if we need to save now
             if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
-                self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric)
+                self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric, train_time=total_train_time)
 
             # if early stopping has been specified then check if the stopping condition is met
             if self.early_stopping is not None:
@@ -591,6 +598,9 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
         if self._generate_name:
             self.name = None
 
+        # Update the train time
+        self.train_time = total_train_time
+
         # TODO save the model to wandb if we have the wandb logger
         if isinstance(self.logger, WandbLogger):
             self.logger.get_wandb().finish()
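The timing logic in fit() is a simple accumulation pattern: each call measures its own wall-clock time from t_start and adds the self.train_time carried over from earlier runs (restored by load_checkpoint), so interrupted and resumed trainings report one cumulative total. A standalone sketch of that pattern, independent of torch_em; the TimedRun class is a toy stand-in:

import time

class TimedRun:
    """Toy stand-in for the trainer's cumulative timing."""

    def __init__(self):
        self.train_time = 0.0  # would be restored from a checkpoint on resume

    def fit(self, n_steps):
        t_start = time.time()
        for _ in range(n_steps):
            time.sleep(0.01)  # stand-in for one training iteration
        # This run's duration plus the time carried over from previous runs.
        self.train_time = (time.time() - t_start) + self.train_time

run = TimedRun()
run.fit(10)
first = run.train_time
run.fit(10)
assert run.train_time > first  # the total keeps growing across fit() calls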
