Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved logic of update of _best_ckpt_metrics before we build state dict #2007

Merged
merged 1 commit into from
Jun 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,15 @@ def _save_checkpoint(
train_metrics_titles = get_metrics_titles(self.train_metrics)
all_metrics["train"] = {metric_name: float(train_metrics_dict[metric_name]) for metric_name in train_metrics_titles}

best_checkpoint = (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
)

if best_checkpoint:
# STORE THE CURRENT metric AS BEST
self.best_metric = curr_tracked_metric
self._best_ckpt_metrics = all_metrics

# BUILD THE state_dict
state = {
"net": unwrap_model(self.net).state_dict(),
Expand Down Expand Up @@ -713,13 +722,7 @@ def _save_checkpoint(
self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch)

# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
):
# STORE THE CURRENT metric AS BEST
self.best_metric = curr_tracked_metric

self._best_ckpt_metrics = all_metrics
if best_checkpoint:
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)

# RUN PHASE CALLBACKS
Expand Down
Loading