From 879cfec1b49d0255ed963f44b3a9f55d42c9920a Mon Sep 17 00:00:00 2001
From: sadamov <45732287+sadamov@users.noreply.github.com>
Date: Wed, 29 May 2024 16:07:36 +0200
Subject: [PATCH] Make restoration of optimizer and scheduler more robust (#17)

## Summary
This pull request makes model loading and optimizer/scheduler state
restoration more robust, improving flexibility and compatibility with
multi-GPU setups.

## Detailed Changes
- **Enhanced Model Loading for Multi-GPU**: Modified the model loading
  logic to better support multi-GPU environments by ensuring that
  optimizer states are only loaded when necessary and appropriate (see
  the sketch after this list).
- **Checkpoint Adjustments**: Adjusted how learning rate schedulers are
  restored from checkpoints to ensure they align correctly with the
  current training state.
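A minimal sketch of the mechanism, assuming a standard
`pytorch_lightning.LightningModule`. `ARModelSketch` and its single
layer are toy stand-ins; only the `restore_opt` flag and the
`on_load_checkpoint` override mirror what this patch does in
`ar_model.py`:

```python
import pytorch_lightning as pl
import torch


class ARModelSketch(pl.LightningModule):
    """Toy module illustrating optional optimizer restoration."""

    def __init__(self, lr=1e-3, restore_opt=False):
        super().__init__()
        self.lr = lr
        # Whether to keep the optimizer state stored in the checkpoint
        self.restore_opt = restore_opt
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        # A plain AdamW; no manual opt.load_state_dict() needed any more
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

    def on_load_checkpoint(self, checkpoint):
        # When opting out of optimizer restoration, overwrite the saved
        # optimizer states with those of a freshly built optimizer, so
        # Lightning restores a clean state on every rank.
        if not self.restore_opt:
            opt = self.configure_optimizers()
            checkpoint["optimizer_states"] = [opt.state_dict()]
```

Because the override happens inside the checkpoint-loading hook rather
than in a manual `torch.load()` on one process, it behaves identically
on every rank in a distributed run.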
## Impact
These changes give users greater control over how training states are
restored and improve the script's behaviour in distributed training
environments.

## Testing
- [x] Changes have been tested in both single- and multi-GPU setups

## Notes
Further integration testing with different training configurations is
recommended to fully validate the new functionality.

---------

Co-authored-by: Simon Adamov
---
 CHANGELOG.md                  |  4 ++++
 neural_lam/models/ar_model.py | 10 +++++-----
 neural_lam/vis.py             |  4 ++--
 train_model.py                | 12 +++---------
 4 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 63feff96..061aa6bb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- Robust restoration of optimizer and scheduler using `ckpt_path`
+  [\#17](https://github.com/mllam/neural-lam/pull/17)
+  @sadamov
+
 - Updated scripts and modules to use `data_config.yaml` instead of `constants.py`
   [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
   @sadamov

diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 9cda9fc2..29b169d4 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -83,8 +83,8 @@ def __init__(self, args):
         if self.output_std:
             self.test_metrics["output_std"] = []  # Treat as metric
 
-        # For making restoring of optimizer state optional (slight hack)
-        self.opt_state = None
+        # For making restoring of optimizer state optional
+        self.restore_opt = args.restore_opt
 
         # For example plotting
         self.n_example_pred = args.n_example_pred
@@ -97,9 +97,6 @@ def configure_optimizers(self):
         opt = torch.optim.AdamW(
             self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
         )
-        if self.opt_state:
-            opt.load_state_dict(self.opt_state)
-
         return opt
 
     @property
@@ -597,3 +594,6 @@ def on_load_checkpoint(self, checkpoint):
                     )
                     loaded_state_dict[new_key] = loaded_state_dict[old_key]
                     del loaded_state_dict[old_key]
+        if not self.restore_opt:
+            opt = self.configure_optimizers()
+            checkpoint["optimizer_states"] = [opt.state_dict()]

diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 2b6abf15..8c9ca77c 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -87,7 +87,7 @@ def plot_prediction(
         1,
         2,
         figsize=(13, 7),
-        subplot_kw={"projection": data_config.coords_projection()},
+        subplot_kw={"projection": data_config.coords_projection},
     )
 
     # Plot pred and target
@@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
 
     fig, ax = plt.subplots(
         figsize=(5, 4.8),
-        subplot_kw={"projection": data_config.coords_projection()},
+        subplot_kw={"projection": data_config.coords_projection},
     )
 
     ax.coastlines()  # Add coastline outlines

diff --git a/train_model.py b/train_model.py
index 390da6d4..fe064384 100644
--- a/train_model.py
+++ b/train_model.py
@@ -265,14 +265,7 @@ def main():
 
     # Load model parameters Use new args for model
     model_class = MODELS[args.model]
-    if args.load:
-        model = model_class.load_from_checkpoint(args.load, args=args)
-        if args.restore_opt:
-            # Save for later
-            # Unclear if this works for multi-GPU
-            model.opt_state = torch.load(args.load)["optimizer_states"][0]
-    else:
-        model = model_class(args)
+    model = model_class(args)
 
     prefix = "subset-" if args.subset_ds else ""
     if args.eval:
@@ -327,13 +320,14 @@ def main():
         )
 
         print(f"Running evaluation on {args.eval}")
-        trainer.test(model=model, dataloaders=eval_loader)
+        trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load)
     else:
         # Train model
         trainer.fit(
             model=model,
             train_dataloaders=train_loader,
             val_dataloaders=val_loader,
+            ckpt_path=args.load,
         )
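
For reference, a hedged usage sketch of the resulting flow. It reuses
the toy `ARModelSketch` from the sketch above; the dataset, checkpoint
path, and trainer settings are illustrative placeholders, not part of
this patch:

```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data standing in for neural-lam's real dataloaders
loader = DataLoader(
    TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=4
)

model = ARModelSketch(restore_opt=False)
trainer = pl.Trainer(max_epochs=1, devices="auto")

# Passing ckpt_path lets Lightning itself restore model weights plus
# optimizer and scheduler states on every rank (assuming "last.ckpt"
# was written by an earlier run), replacing the old manual torch.load().
trainer.fit(model, train_dataloaders=loader, ckpt_path="last.ckpt")
```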