Skip to content

Commit

Permalink
combine two wandb init functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed May 25, 2024
1 parent 72272bc commit 57a396c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
11 changes: 5 additions & 6 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,7 @@ def init_wandb_metrics(wandb_logger):
"""
Set up wandb metrics to track
"""
experiment = wandb_logger.experiment
experiment.define_metric("val_mean_loss", summary="min")
for step in constants.VAL_STEP_LOG_ERRORS:
experiment.define_metric(f"val_loss_unroll{step}", summary="min")



@rank_zero_only
Expand Down Expand Up @@ -303,7 +300,6 @@ def init_wandb(args):
project=constants.WANDB_PROJECT,
name=run_name,
config=args,
log_model=True,
)
wandb.save("neural_lam/constants.py")
else:
Expand All @@ -317,7 +313,10 @@ def init_wandb(args):
project=constants.WANDB_PROJECT,
id=args.resume_run,
config=args,
log_model=True,
)
experiment = logger.experiment
experiment.define_metric("val_mean_loss", summary="min")
for step in constants.VAL_STEP_LOG_ERRORS:
experiment.define_metric(f"val_loss_unroll{step}", summary="min")

return logger
4 changes: 0 additions & 4 deletions train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,6 @@ def main():
precision=args.precision,
)

# Only init once, on rank 0 only
if trainer.global_rank == 0:
utils.init_wandb_metrics(logger) # Do after wandb.init

if args.eval:
if args.eval == "val":
eval_loader = val_loader
Expand Down

0 comments on commit 57a396c

Please sign in to comment.