From 57a396c3d0b0e3eb70dbc57093eb2dd935a5d61a Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sat, 25 May 2024 16:07:16 +0200
Subject: [PATCH] combine two wandb init functions

---
 neural_lam/utils.py | 11 +++++------
 train_model.py      |  4 ----
 2 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 6e0ec15b..943fc84e 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -271,10 +271,7 @@ def init_wandb_metrics(wandb_logger):
     """
     Set up wandb metrics to track
     """
-    experiment = wandb_logger.experiment
-    experiment.define_metric("val_mean_loss", summary="min")
-    for step in constants.VAL_STEP_LOG_ERRORS:
-        experiment.define_metric(f"val_loss_unroll{step}", summary="min")
+
 
 
 @rank_zero_only
@@ -303,7 +300,6 @@ def init_wandb(args):
             project=constants.WANDB_PROJECT,
             name=run_name,
             config=args,
-            log_model=True,
         )
         wandb.save("neural_lam/constants.py")
     else:
@@ -317,7 +313,10 @@ def init_wandb(args):
             project=constants.WANDB_PROJECT,
             id=args.resume_run,
             config=args,
-            log_model=True,
         )
+    experiment = logger.experiment
+    experiment.define_metric("val_mean_loss", summary="min")
+    for step in constants.VAL_STEP_LOG_ERRORS:
+        experiment.define_metric(f"val_loss_unroll{step}", summary="min")
 
     return logger
diff --git a/train_model.py b/train_model.py
index 2c579050..bca3f638 100644
--- a/train_model.py
+++ b/train_model.py
@@ -281,10 +281,6 @@ def main():
         precision=args.precision,
     )
 
-    # Only init once, on rank 0 only
-    if trainer.global_rank == 0:
-        utils.init_wandb_metrics(logger)  # Do after wandb.init
-
     if args.eval:
         if args.eval == "val":
             eval_loader = val_loader