Make restoration of optimizer and scheduler more robust (#17)
## Summary
This pull request makes model loading and optimizer/scheduler state
restoration more robust, improving flexibility and compatibility with
multi-GPU setups.

## Detailed Changes
- **Enhanced Model Loading for Multi-GPU**: Modified the model loading
logic to better support multi-GPU environments: optimizer states are
only loaded when explicitly requested (see the sketch after this list).
- **Checkpoint Adjustments**: Learning rate schedulers are now restored
from checkpoints via `ckpt_path`, so they align correctly with the
current training state.
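
The core mechanism behind both points is previewed below as a minimal, self-contained sketch (simplified from the diff to `ar_model.py` further down; the toy `Model` class and its single layer are illustrative, not the project's actual module):

```python
import torch
from torch import nn
import pytorch_lightning as pl


class Model(pl.LightningModule):
    """Toy module illustrating optional optimizer-state restoration."""

    def __init__(self, restore_opt=True):
        super().__init__()
        self.restore_opt = restore_opt
        self.net = nn.Linear(4, 4)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3, betas=(0.9, 0.95))

    def on_load_checkpoint(self, checkpoint):
        # Lightning calls this hook before restoring state from a checkpoint.
        # Overwriting the stored optimizer states with those of a freshly
        # built optimizer turns the subsequent restore into a clean reset.
        if not self.restore_opt:
            opt = self.configure_optimizers()
            checkpoint["optimizer_states"] = [opt.state_dict()]
```

Because the checkpoint path is now passed to `trainer.fit`/`trainer.test` as `ckpt_path` (see `train_model.py` below), Lightning itself restores the learning rate scheduler together with the optimizer, keeping both aligned with the resumed global step.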

## Impact
These changes give users greater control over how training states are
restored and make checkpoint restoration behave correctly in
distributed training environments.

## Testing
- [x] Changes have been tested in both single- and multi-GPU setups

## Notes
Further integration testing across different training configurations is
recommended to fully validate the new functionality.

---------

Co-authored-by: Simon Adamov <simon.adamov@mailbox.org>
sadamov and Simon Adamov authored May 29, 2024
1 parent 5b71be3 commit 879cfec
Showing 4 changed files with 14 additions and 16 deletions.
CHANGELOG.md (4 additions, 0 deletions)

```diff
@@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- Robust restoration of optimizer and scheduler using `ckpt_path`
+  [\#17](https://github.com/mllam/neural-lam/pull/17)
+  @sadamov
+
 - Updated scripts and modules to use `data_config.yaml` instead of `constants.py`
   [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
   @sadamov
```
neural_lam/models/ar_model.py (5 additions, 5 deletions)

```diff
@@ -83,8 +83,8 @@ def __init__(self, args):
         if self.output_std:
             self.test_metrics["output_std"] = [] # Treat as metric
 
-        # For making restoring of optimizer state optional (slight hack)
-        self.opt_state = None
+        # For making restoring of optimizer state optional
+        self.restore_opt = args.restore_opt
 
         # For example plotting
         self.n_example_pred = args.n_example_pred
@@ -97,9 +97,6 @@ def configure_optimizers(self):
         opt = torch.optim.AdamW(
             self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
         )
-        if self.opt_state:
-            opt.load_state_dict(self.opt_state)
-
         return opt
 
     @property
@@ -597,3 +594,6 @@ def on_load_checkpoint(self, checkpoint):
                 )
                 loaded_state_dict[new_key] = loaded_state_dict[old_key]
                 del loaded_state_dict[old_key]
+        if not self.restore_opt:
+            opt = self.configure_optimizers()
+            checkpoint["optimizer_states"] = [opt.state_dict()]
```
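Substituting a fresh optimizer's `state_dict()` works because a newly constructed optimizer has an empty per-parameter `state`: Lightning then restores the hyperparameters but no stale momentum/variance buffers. A quick check (toy layer and hyperparameters chosen for illustration):

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)  # toy stand-in for the model's parameters
opt = torch.optim.AdamW(layer.parameters(), lr=1e-3, betas=(0.9, 0.95))

sd = opt.state_dict()
print(sd["state"])         # {} -- no per-parameter state before any step
print(sd["param_groups"])  # lr, betas, weight_decay, ... are still present
```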
neural_lam/vis.py (2 additions, 2 deletions)

```diff
@@ -87,7 +87,7 @@ def plot_prediction(
         1,
         2,
         figsize=(13, 7),
-        subplot_kw={"projection": data_config.coords_projection()},
+        subplot_kw={"projection": data_config.coords_projection},
     )
 
     # Plot pred and target
@@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
 
     fig, ax = plt.subplots(
         figsize=(5, 4.8),
-        subplot_kw={"projection": data_config.coords_projection()},
+        subplot_kw={"projection": data_config.coords_projection},
     )
 
     ax.coastlines()  # Add coastline outlines
```
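The `vis.py` change only drops the call parentheses, which is consistent with the config object exposing `coords_projection` as a read-only property instead of a method. A hypothetical sketch of such an accessor (the concrete projection class is an assumption for illustration, not taken from the repository):

```python
import cartopy.crs as ccrs


class DataConfig:
    """Hypothetical sketch of the config object consumed by vis.py."""

    @property
    def coords_projection(self):
        # Accessed without parentheses: data_config.coords_projection
        # The specific projection here is illustrative only.
        return ccrs.LambertConformal()
```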
train_model.py (3 additions, 9 deletions)

```diff
@@ -265,14 +265,7 @@ def main():
 
     # Load model parameters Use new args for model
     model_class = MODELS[args.model]
-    if args.load:
-        model = model_class.load_from_checkpoint(args.load, args=args)
-        if args.restore_opt:
-            # Save for later
-            # Unclear if this works for multi-GPU
-            model.opt_state = torch.load(args.load)["optimizer_states"][0]
-    else:
-        model = model_class(args)
+    model = model_class(args)
 
     prefix = "subset-" if args.subset_ds else ""
     if args.eval:
@@ -327,13 +320,14 @@ def main():
         )
 
         print(f"Running evaluation on {args.eval}")
-        trainer.test(model=model, dataloaders=eval_loader)
+        trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load)
     else:
         # Train model
         trainer.fit(
             model=model,
             train_dataloaders=train_loader,
             val_dataloaders=val_loader,
+            ckpt_path=args.load,
         )
 
 
```
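For reference, the resume semantics that the new `ckpt_path` arguments rely on can be reproduced end to end with a toy module. This is a minimal sketch under assumed names (`TinyModule`, synthetic data), not the project's training script:

```python
import torch
from torch import nn
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3, betas=(0.9, 0.95))


data = torch.utils.data.TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
loader = torch.utils.data.DataLoader(data, batch_size=8)

# First run: train briefly and let Lightning write a checkpoint
trainer = pl.Trainer(max_epochs=1, logger=False)
trainer.fit(TinyModule(), train_dataloaders=loader)
ckpt = trainer.checkpoint_callback.best_model_path

# Second run: ckpt_path restores weights plus optimizer/scheduler state,
# mirroring trainer.fit(..., ckpt_path=args.load) in train_model.py
trainer2 = pl.Trainer(max_epochs=2, logger=False)
trainer2.fit(TinyModule(), train_dataloaders=loader, ckpt_path=ckpt)
```

With `args.restore_opt` unset, the `on_load_checkpoint` override in `ar_model.py` would swap in a fresh optimizer state before this restore takes place.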
