From 83d50bfed8ce7b6709632b9148b31a70715a3321 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 15 May 2024 15:16:02 +0200 Subject: [PATCH 1/5] Add changelog (#28) Add changelog of develops so far including those since last tagged commit (`v0.1.0`) --- CHANGELOG.md | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..19ecdd41 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,51 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + + +## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD) + +### Added + +- new metrics (`nll` and `crps_gauss`) and `metrics` submodule, stddiv output option + [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) + @joeloskarsson + +- ability to "watch" metrics and log + [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) + @joeloskarsson + +- pre-commit setup for linting and formatting + [\#6](https://github.com/joeloskarsson/neural-lam/pull/6), [\#8](https://github.com/joeloskarsson/neural-lam/pull/8) + @sadamov, @joeloskarsson + +### Changed + +- moved batch-static features ("water cover") into forcing component return by `WeatherDataset` + [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) + @joeloskarsson + +- change validation metric from `mae` to `rmse` + [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) + @joeloskarsson + +- change RMSE definition to compute sqrt after all averaging + [\#10](https://github.com/joeloskarsson/neural-lam/pull/10) + @joeloskarsson + +### Removed + +- `WeatherDataset(torch.Dataset)` no longer returns "batch-static" component of + training item (only `prev_state`, `target_state` and `forcing`), the batch static features are + instead included in forcing + [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) + @joeloskarsson + + +## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0) + +First tagged release of `neural-lam`, matching Oskarsson et al 2023 publication +(https://arxiv.org/abs/2309.17370) From 4a97a1209e8fceadef8162f13db9e4805681d22e Mon Sep 17 00:00:00 2001 From: sadamov <45732287+sadamov@users.noreply.github.com> Date: Wed, 22 May 2024 10:22:16 +0200 Subject: [PATCH 2/5] Replace constants.py with data_config.yaml (#31) **Summary** This PR replaces the `constants.py` file with a `data_config.yaml` file. Dataset related settings can be defined by the user in the new yaml file. Training specific settings were added as additional flags to the `train_model.py` routine. All respective calls to the old files were replaced. **Rationale** - Using a Yaml file for data config gives much more flexibility for various datasets used in the community. It also facilitates the future use of forcing and boundary datasets. In a follow-up PR the dataset paths will be defined in the yaml file, removing the dependency on a pre-structured `/data` folder. - It is best practice to define user input in a yaml file, the usage of python scripts for that purpose is not common. 
- The old `constants.py` actually combined both constants and variables, many "constants" should rather be flags to `train_models.py` - The introduction of a new ConfigClass in `utils.py` allows for very specific queries of the yaml and calculations based thereon. This branch shows future possibilities of such a class https://github.com/joeloskarsson/neural-lam/tree/feature_dataset_yaml **Testing** Both training and evaluation of the model were succesfully tested with the `meps_example` dataset. **Note** @leifdenby Could you invite Thomas R. to this repo, in case he wanted to give his input on the yaml file? This PR should mostly serve as a basis for discussion. Maybe we should add more information to the yaml file as you outline in https://github.com/mllam/mllam-data-prep. I think we should always keep in mind how the repository will look like with realistic boundary conditions and zarr-archives as data-input. This PR solves parts of https://github.com/joeloskarsson/neural-lam/issues/23 --------- Co-authored-by: Simon Adamov --- .gitignore | 1 + CHANGELOG.md | 16 ++++- README.md | 5 +- create_grid_features.py | 12 ++-- create_mesh.py | 13 ++-- create_parameter_weights.py | 20 +++--- neural_lam/config.py | 62 ++++++++++++++++++ neural_lam/constants.py | 120 ---------------------------------- neural_lam/data_config.yaml | 64 ++++++++++++++++++ neural_lam/models/ar_model.py | 68 ++++++++++--------- neural_lam/utils.py | 7 +- neural_lam/vis.py | 30 +++++---- neural_lam/weather_dataset.py | 6 +- plot_graph.py | 11 ++-- train_model.py | 51 +++++++++++---- 15 files changed, 274 insertions(+), 212 deletions(-) create mode 100644 neural_lam/config.py delete mode 100644 neural_lam/constants.py create mode 100644 neural_lam/data_config.yaml diff --git a/.gitignore b/.gitignore index 7bb826a2..c9d914c2 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ graphs *.sif sweeps test_*.sh +.vscode ### Python ### # Byte-compiled / optimized / DLL files diff --git a/CHANGELOG.md b/CHANGELOG.md index 19ecdd41..823ac8b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,11 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
- ## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD) ### Added +- Replaced `constants.py` with `data_config.yaml` for data configuration management + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + - new metrics (`nll` and `crps_gauss`) and `metrics` submodule, stddiv output option [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) @joeloskarsson @@ -24,6 +27,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Updated scripts and modules to use `data_config.yaml` instead of `constants.py` + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + +- Added new flags in `train_model.py` for configuration previously in `constants.py` + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + - moved batch-static features ("water cover") into forcing component return by `WeatherDataset` [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) @joeloskarsson @@ -44,8 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) @joeloskarsson - ## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0) First tagged release of `neural-lam`, matching Oskarsson et al 2023 publication -(https://arxiv.org/abs/2309.17370) +() diff --git a/README.md b/README.md index 67d9d9b1..ba0bb3fe 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Still, some restrictions are inevitable: ## A note on the limited area setting Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)). There are still some parts of the code that is quite specific for the MEPS area use case. -This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/constants.py`). +This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified in `train_model.py --data_config` ). If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic. We would be happy to support such enhancements. See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done. @@ -104,13 +104,12 @@ The graph-related files are stored in a directory called `graphs`. ### Create remaining static features To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`. -The main option to set for these is just which dataset to use. ## Weights & Biases Integration The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it. When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface. If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`. -The W&B project name is set to `neural-lam`, but this can be changed in `neural_lam/constants.py`. 
+The W&B project name is set to `neural-lam`, but this can be changed in the flags of `train_model.py` (using argsparse). See the [W&B documentation](https://docs.wandb.ai/) for details. If you would like to login and use W&B, run: diff --git a/create_grid_features.py b/create_grid_features.py index c9038103..c3714368 100644 --- a/create_grid_features.py +++ b/create_grid_features.py @@ -6,6 +6,9 @@ import numpy as np import torch +# First-party +from neural_lam import config + def main(): """ @@ -13,14 +16,15 @@ def main(): """ parser = ArgumentParser(description="Training arguments") parser.add_argument( - "--dataset", + "--data_config", type=str, - default="meps_example", - help="Dataset to compute weights for (default: meps_example)", + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", ) args = parser.parse_args() + config_loader = config.Config.from_file(args.data_config) - static_dir_path = os.path.join("data", args.dataset, "static") + static_dir_path = os.path.join("data", config_loader.dataset.name, "static") # -- Static grid node features -- grid_xy = torch.tensor( diff --git a/create_mesh.py b/create_mesh.py index cb524cd6..f04b4d4b 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -12,6 +12,9 @@ import torch_geometric as pyg from torch_geometric.utils.convert import from_networkx +# First-party +from neural_lam import config + def plot_graph(graph, title=None): fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H @@ -153,11 +156,10 @@ def prepend_node_index(graph, new_index): def main(): parser = ArgumentParser(description="Graph generation arguments") parser.add_argument( - "--dataset", + "--data_config", type=str, - default="meps_example", - help="Dataset to load grid point coordinates from " - "(default: meps_example)", + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", ) parser.add_argument( "--graph", @@ -187,7 +189,8 @@ def main(): args = parser.parse_args() # Load grid positions - static_dir_path = os.path.join("data", args.dataset, "static") + config_loader = config.Config.from_file(args.data_config) + static_dir_path = os.path.join("data", config_loader.dataset.name, "static") graph_dir_path = os.path.join("graphs", args.graph) os.makedirs(graph_dir_path, exist_ok=True) diff --git a/create_parameter_weights.py b/create_parameter_weights.py index 494a5e81..cae1ae3e 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -8,7 +8,7 @@ from tqdm import tqdm # First-party -from neural_lam import constants +from neural_lam import config from neural_lam.weather_dataset import WeatherDataset @@ -18,10 +18,10 @@ def main(): """ parser = ArgumentParser(description="Training arguments") parser.add_argument( - "--dataset", + "--data_config", type=str, - default="meps_example", - help="Dataset to compute weights for (default: meps_example)", + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", ) parser.add_argument( "--batch_size", @@ -43,7 +43,8 @@ def main(): ) args = parser.parse_args() - static_dir_path = os.path.join("data", args.dataset, "static") + config_loader = config.Config.from_file(args.data_config) + static_dir_path = os.path.join("data", config_loader.dataset.name, "static") # Create parameter weights based on height # based on fig A.1 in graph cast paper @@ -56,7 +57,10 @@ def main(): "500": 0.03, } w_list = np.array( - [w_dict[par.split("_")[-2]] for 
par in constants.PARAM_NAMES] + [ + w_dict[par.split("_")[-2]] + for par in config_loader.dataset.var_longnames + ] ) print("Saving parameter weights...") np.save( @@ -66,7 +70,7 @@ def main(): # Load dataset without any subsampling ds = WeatherDataset( - args.dataset, + config_loader.dataset.name, split="train", subsample_step=1, pred_length=63, @@ -113,7 +117,7 @@ def main(): # Compute mean and std.-dev. of one-step differences across the dataset print("Computing mean and std.-dev. for one-step differences...") ds_standard = WeatherDataset( - args.dataset, + config_loader.dataset.name, split="train", subsample_step=1, pred_length=63, diff --git a/neural_lam/config.py b/neural_lam/config.py new file mode 100644 index 00000000..5891ea74 --- /dev/null +++ b/neural_lam/config.py @@ -0,0 +1,62 @@ +# Standard library +import functools +from pathlib import Path + +# Third-party +import cartopy.crs as ccrs +import yaml + + +class Config: + """ + Class for loading configuration files. + + This class loads a configuration file and provides a way to access its + values as attributes. + """ + + def __init__(self, values): + self.values = values + + @classmethod + def from_file(cls, filepath): + """Load a configuration file.""" + if filepath.endswith(".yaml"): + with open(filepath, encoding="utf-8", mode="r") as file: + return cls(values=yaml.safe_load(file)) + else: + raise NotImplementedError(Path(filepath).suffix) + + def __getattr__(self, name): + keys = name.split(".") + value = self.values + for key in keys: + if key in value: + value = value[key] + else: + return None + if isinstance(value, dict): + return Config(values=value) + return value + + def __getitem__(self, key): + value = self.values[key] + if isinstance(value, dict): + return Config(values=value) + return value + + def __contains__(self, key): + return key in self.values + + def num_data_vars(self): + """Return the number of data variables for a given key.""" + return len(self.dataset.var_names) + + @functools.cached_property + def coords_projection(self): + """Return the projection.""" + proj_config = self.values["projection"] + proj_class_name = proj_config["class"] + proj_class = getattr(ccrs, proj_class_name) + proj_params = proj_config.get("kwargs", {}) + return proj_class(**proj_params) diff --git a/neural_lam/constants.py b/neural_lam/constants.py deleted file mode 100644 index 527c31d8..00000000 --- a/neural_lam/constants.py +++ /dev/null @@ -1,120 +0,0 @@ -# Third-party -import cartopy -import numpy as np - -WANDB_PROJECT = "neural-lam" - -SECONDS_IN_YEAR = ( - 365 * 24 * 60 * 60 -) # Assuming no leap years in dataset (2024 is next) - -# Log prediction error for these lead times -VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 5, 10, 15, 19]) - -# Log these metrics to wandb as scalar values for -# specific variables and lead times -# List of metrics to watch, including any prefix (e.g. 
val_rmse) -METRICS_WATCH = [] -# Dict with variables and lead times to log watched metrics for -# Format is a dictionary that maps from a variable index to -# a list of lead time steps -VAR_LEADS_METRICS_WATCH = { - 6: [2, 19], # t_2 - 14: [2, 19], # wvint_0 - 15: [2, 19], # z_1000 -} - -# Variable names -PARAM_NAMES = [ - "pres_heightAboveGround_0_instant", - "pres_heightAboveSea_0_instant", - "nlwrs_heightAboveGround_0_accum", - "nswrs_heightAboveGround_0_accum", - "r_heightAboveGround_2_instant", - "r_hybrid_65_instant", - "t_heightAboveGround_2_instant", - "t_hybrid_65_instant", - "t_isobaricInhPa_500_instant", - "t_isobaricInhPa_850_instant", - "u_hybrid_65_instant", - "u_isobaricInhPa_850_instant", - "v_hybrid_65_instant", - "v_isobaricInhPa_850_instant", - "wvint_entireAtmosphere_0_instant", - "z_isobaricInhPa_1000_instant", - "z_isobaricInhPa_500_instant", -] - -PARAM_NAMES_SHORT = [ - "pres_0g", - "pres_0s", - "nlwrs_0", - "nswrs_0", - "r_2", - "r_65", - "t_2", - "t_65", - "t_500", - "t_850", - "u_65", - "u_850", - "v_65", - "v_850", - "wvint_0", - "z_1000", - "z_500", -] -PARAM_UNITS = [ - "Pa", - "Pa", - "W/m\\textsuperscript{2}", - "W/m\\textsuperscript{2}", - "-", # unitless - "-", - "K", - "K", - "K", - "K", - "m/s", - "m/s", - "m/s", - "m/s", - "kg/m\\textsuperscript{2}", - "m\\textsuperscript{2}/s\\textsuperscript{2}", - "m\\textsuperscript{2}/s\\textsuperscript{2}", -] - -# Projection and grid -# Hard coded for now, but should eventually be part of dataset desc. files -GRID_SHAPE = (268, 238) # (y, x) - -LAMBERT_PROJ_PARAMS = { - "a": 6367470, - "b": 6367470, - "lat_0": 63.3, - "lat_1": 63.3, - "lat_2": 63.3, - "lon_0": 15.0, - "proj": "lcc", -} - -GRID_LIMITS = [ # In projection - -1059506.5523409774, # min x - 1310493.4476590226, # max x - -1331732.4471934352, # min y - 1338267.5528065648, # max y -] - -# Create projection -LAMBERT_PROJ = cartopy.crs.LambertConformal( - central_longitude=LAMBERT_PROJ_PARAMS["lon_0"], - central_latitude=LAMBERT_PROJ_PARAMS["lat_0"], - standard_parallels=( - LAMBERT_PROJ_PARAMS["lat_1"], - LAMBERT_PROJ_PARAMS["lat_2"], - ), -) - -# Data dimensions -GRID_FORCING_DIM = 5 * 3 + 1 # 5 feat. 
for 3 time-step window + 1 batch-static -GRID_STATE_DIM = 17 diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml new file mode 100644 index 00000000..f16a4a30 --- /dev/null +++ b/neural_lam/data_config.yaml @@ -0,0 +1,64 @@ +dataset: + name: meps_example + var_names: + - pres_0g + - pres_0s + - nlwrs_0 + - nswrs_0 + - r_2 + - r_65 + - t_2 + - t_65 + - t_500 + - t_850 + - u_65 + - u_850 + - v_65 + - v_850 + - wvint_0 + - z_1000 + - z_500 + var_units: + - Pa + - Pa + - r"$\mathrm{W}/\mathrm{m}^2$" + - r"$\mathrm{W}/\mathrm{m}^2$" + - "" + - "" + - K + - K + - K + - K + - m/s + - m/s + - m/s + - m/s + - r"$\mathrm{kg}/\mathrm{m}^2$" + - r"$\mathrm{m}^2/\mathrm{s}^2$" + - r"$\mathrm{m}^2/\mathrm{s}^2$" + var_longnames: + - pres_heightAboveGround_0_instant + - pres_heightAboveSea_0_instant + - nlwrs_heightAboveGround_0_accum + - nswrs_heightAboveGround_0_accum + - r_heightAboveGround_2_instant + - r_hybrid_65_instant + - t_heightAboveGround_2_instant + - t_hybrid_65_instant + - t_isobaricInhPa_500_instant + - t_isobaricInhPa_850_instant + - u_hybrid_65_instant + - u_isobaricInhPa_850_instant + - v_hybrid_65_instant + - v_isobaricInhPa_850_instant + - wvint_entireAtmosphere_0_instant + - z_isobaricInhPa_1000_instant + - z_isobaricInhPa_500_instant + num_forcing_features: 16 +grid_shape_state: [268, 238] +projection: + class: LambertConformal + kwargs: + central_longitude: 15.0 + central_latitude: 63.3 + standard_parallels: [63.3, 63.3] diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 7d0a8320..9cda9fc2 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -9,7 +9,7 @@ import wandb # First-party -from neural_lam import constants, metrics, utils, vis +from neural_lam import config, metrics, utils, vis class ARModel(pl.LightningModule): @@ -24,10 +24,13 @@ class ARModel(pl.LightningModule): def __init__(self, args): super().__init__() self.save_hyperparameters() - self.lr = args.lr + self.args = args + self.config_loader = config.Config.from_file(args.data_config) # Load static features for grid/data - static_data_dict = utils.load_static_data(args.dataset) + static_data_dict = utils.load_static_data( + self.config_loader.dataset.name + ) for static_data_name, static_data_tensor in static_data_dict.items(): self.register_buffer( static_data_name, static_data_tensor, persistent=False @@ -36,14 +39,11 @@ def __init__(self, args): # Double grid output dim. to also output std.-dev. self.output_std = bool(args.output_std) if self.output_std: - self.grid_output_dim = ( - 2 * constants.GRID_STATE_DIM - ) # Pred. dim. in grid cell + # Pred. dim. in grid cell + self.grid_output_dim = 2 * self.config_loader.num_data_vars() else: - self.grid_output_dim = ( - constants.GRID_STATE_DIM - ) # Pred. dim. in grid cell - + # Pred. dim. in grid cell + self.grid_output_dim = self.config_loader.num_data_vars() # Store constant per-variable std.-dev. 
weighting # Note that this is the inverse of the multiplicative weighting # in wMSE/wMAE @@ -57,11 +57,11 @@ def __init__(self, args): ( self.num_grid_nodes, grid_static_dim, - ) = self.grid_static_features.shape # 63784 = 268x238 + ) = self.grid_static_features.shape self.grid_dim = ( - 2 * constants.GRID_STATE_DIM + 2 * self.config_loader.num_data_vars() + grid_static_dim - + constants.GRID_FORCING_DIM + + self.config_loader.dataset.num_forcing_features ) # Instantiate loss function @@ -95,7 +95,7 @@ def __init__(self, args): def configure_optimizers(self): opt = torch.optim.AdamW( - self.parameters(), lr=self.lr, betas=(0.9, 0.95) + self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) ) if self.opt_state: opt.load_state_dict(self.opt_state) @@ -246,7 +246,7 @@ def validation_step(self, batch, batch_idx): # Log loss per time step forward and mean val_log_dict = { f"val_loss_unroll{step}": time_step_loss[step - 1] - for step in constants.VAL_STEP_LOG_ERRORS + for step in self.args.val_steps_to_log } val_log_dict["val_mean_loss"] = mean_loss self.log_dict( @@ -294,7 +294,7 @@ def test_step(self, batch, batch_idx): # Log loss per time step forward and mean test_log_dict = { f"test_loss_unroll{step}": time_step_loss[step - 1] - for step in constants.VAL_STEP_LOG_ERRORS + for step in self.args.val_steps_to_log } test_log_dict["test_mean_loss"] = mean_loss @@ -328,7 +328,9 @@ def test_step(self, batch, batch_idx): spatial_loss = self.loss( prediction, target, pred_std, average_grid=False ) # (B, pred_steps, num_grid_nodes) - log_spatial_losses = spatial_loss[:, constants.VAL_STEP_LOG_ERRORS - 1] + log_spatial_losses = spatial_loss[ + :, [step - 1 for step in self.args.val_steps_to_log] + ] self.spatial_loss_maps.append(log_spatial_losses) # (B, N_log, num_grid_nodes) @@ -399,14 +401,15 @@ def plot_examples(self, batch, n_examples, prediction=None): pred_t[:, var_i], target_t[:, var_i], self.interior_mask[:, 0], + self.config_loader, title=f"{var_name} ({var_unit}), " - f"t={t_i} ({self.step_length*t_i} h)", + f"t={t_i} ({self.step_length * t_i} h)", vrange=var_vrange, ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( - constants.PARAM_NAMES_SHORT, - constants.PARAM_UNITS, + self.config_loader.dataset.var_names, + self.config_loader.dataset.var_units, var_vranges, ) ) @@ -417,7 +420,7 @@ def plot_examples(self, batch, n_examples, prediction=None): { f"{var_name}_example_{example_i}": wandb.Image(fig) for var_name, fig in zip( - constants.PARAM_NAMES_SHORT, var_figs + self.config_loader.dataset.var_names, var_figs ) } ) @@ -453,7 +456,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): """ log_dict = {} metric_fig = vis.plot_error_map( - metric_tensor, step_length=self.step_length + metric_tensor, self.config_loader, step_length=self.step_length ) full_log_name = f"{prefix}_{metric_name}" log_dict[full_log_name] = wandb.Image(metric_fig) @@ -471,14 +474,14 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): ) # Check if metrics are watched, log exact values for specific vars - if full_log_name in constants.METRICS_WATCH: - for var_i, timesteps in constants.VAR_LEADS_METRICS_WATCH.items(): - var = constants.PARAM_NAMES_SHORT[var_i] + if full_log_name in self.args.metrics_watch: + for var_i, timesteps in self.args.var_leads_metrics_watch.items(): + var = self.config_loader.dataset.var_nums[var_i] log_dict.update( { f"{full_log_name}_{var}_step_{step}": metric_tensor[ step - 1, var_i - ] # 1-indexed in constants + ] # 1-indexed in 
data_config for step in timesteps } ) @@ -542,10 +545,11 @@ def on_test_epoch_end(self): vis.plot_spatial_error( loss_map, self.interior_mask[:, 0], - title=f"Test loss, t={t_i} ({self.step_length*t_i} h)", + self.config_loader, + title=f"Test loss, t={t_i} ({self.step_length * t_i} h)", ) for t_i, loss_map in zip( - constants.VAL_STEP_LOG_ERRORS, mean_spatial_loss + self.args.val_steps_to_log, mean_spatial_loss ) ] @@ -555,14 +559,14 @@ def on_test_epoch_end(self): # also make without title and save as pdf pdf_loss_map_figs = [ - vis.plot_spatial_error(loss_map, self.interior_mask[:, 0]) + vis.plot_spatial_error( + loss_map, self.interior_mask[:, 0], self.config_loader + ) for loss_map in mean_spatial_loss ] pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps") os.makedirs(pdf_loss_maps_dir, exist_ok=True) - for t_i, fig in zip( - constants.VAL_STEP_LOG_ERRORS, pdf_loss_map_figs - ): + for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs): fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) # save mean spatial loss as .pt file also torch.save( diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 31715502..836b04ed 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -7,9 +7,6 @@ from torch import nn from tueplots import bundles, figsizes -# First-party -from neural_lam import constants - def load_dataset_stats(dataset_name, device="cpu"): """ @@ -263,11 +260,11 @@ def fractional_plot_bundle(fraction): return bundle -def init_wandb_metrics(wandb_logger): +def init_wandb_metrics(wandb_logger, val_steps): """ Set up wandb metrics to track """ experiment = wandb_logger.experiment experiment.define_metric("val_mean_loss", summary="min") - for step in constants.VAL_STEP_LOG_ERRORS: + for step in val_steps: experiment.define_metric(f"val_loss_unroll{step}", summary="min") diff --git a/neural_lam/vis.py b/neural_lam/vis.py index cef34a84..2b6abf15 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -4,11 +4,11 @@ import numpy as np # First-party -from neural_lam import constants, utils +from neural_lam import utils @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_error_map(errors, title=None, step_length=3): +def plot_error_map(errors, data_config, title=None, step_length=3): """ Plot a heatmap of errors of different variables at different predictions horizons @@ -51,7 +51,7 @@ def plot_error_map(errors, title=None, step_length=3): y_ticklabels = [ f"{name} ({unit})" for name, unit in zip( - constants.PARAM_NAMES_SHORT, constants.PARAM_UNITS + data_config.dataset.var_names, data_config.dataset.var_units ) ] ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) @@ -63,7 +63,9 @@ def plot_error_map(errors, title=None, step_length=3): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_prediction(pred, target, obs_mask, title=None, vrange=None): +def plot_prediction( + pred, target, obs_mask, data_config, title=None, vrange=None +): """ Plot example prediction and grond truth. 
Each has shape (N_grid,) @@ -76,23 +78,25 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None): vmin, vmax = vrange # Set up masking of border region - mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE) + mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state) pixel_alpha = ( mask_reshaped.clamp(0.7, 1).cpu().numpy() ) # Faded border region fig, axes = plt.subplots( - 1, 2, figsize=(13, 7), subplot_kw={"projection": constants.LAMBERT_PROJ} + 1, + 2, + figsize=(13, 7), + subplot_kw={"projection": data_config.coords_projection()}, ) # Plot pred and target for ax, data in zip(axes, (target, pred)): ax.coastlines() # Add coastline outlines - data_grid = data.reshape(*constants.GRID_SHAPE).cpu().numpy() + data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy() im = ax.imshow( data_grid, origin="lower", - extent=constants.GRID_LIMITS, alpha=pixel_alpha, vmin=vmin, vmax=vmax, @@ -112,7 +116,7 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_spatial_error(error, obs_mask, title=None, vrange=None): +def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): """ Plot errors over spatial map Error and obs_mask has shape (N_grid,) @@ -125,22 +129,22 @@ def plot_spatial_error(error, obs_mask, title=None, vrange=None): vmin, vmax = vrange # Set up masking of border region - mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE) + mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state) pixel_alpha = ( mask_reshaped.clamp(0.7, 1).cpu().numpy() ) # Faded border region fig, ax = plt.subplots( - figsize=(5, 4.8), subplot_kw={"projection": constants.LAMBERT_PROJ} + figsize=(5, 4.8), + subplot_kw={"projection": data_config.coords_projection()}, ) ax.coastlines() # Add coastline outlines - error_grid = error.reshape(*constants.GRID_SHAPE).cpu().numpy() + error_grid = error.reshape(*data_config.grid_shape_state).cpu().numpy() im = ax.imshow( error_grid, origin="lower", - extent=constants.GRID_LIMITS, alpha=pixel_alpha, vmin=vmin, vmax=vmax, diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index eeefc313..a782806b 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -8,7 +8,7 @@ import torch # First-party -from neural_lam import constants, utils +from neural_lam import utils class WeatherDataset(torch.utils.data.Dataset): @@ -218,9 +218,11 @@ def __getitem__(self, idx): # can roll over to next year, ok because periodicity # Encode as sin/cos + # ! 
Make this more flexible in a separate create_forcings.py script + seconds_in_year = 365 * 24 * 3600 hour_angle = (hour_of_day / 12) * torch.pi # (sample_len,) year_angle = ( - (second_into_year / constants.SECONDS_IN_YEAR) * 2 * torch.pi + (second_into_year / seconds_in_year) * 2 * torch.pi ) # (sample_len,) datetime_forcing = torch.stack( ( diff --git a/plot_graph.py b/plot_graph.py index 48427d5c..40b2b41d 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -7,7 +7,7 @@ import torch_geometric as pyg # First-party -from neural_lam import utils +from neural_lam import config, utils MESH_HEIGHT = 0.1 MESH_LEVEL_DIST = 0.2 @@ -20,10 +20,10 @@ def main(): """ parser = ArgumentParser(description="Plot graph") parser.add_argument( - "--dataset", + "--data_config", type=str, - default="meps_example", - help="Datast to load grid coordinates from (default: meps_example)", + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", ) parser.add_argument( "--graph", @@ -44,6 +44,7 @@ def main(): ) args = parser.parse_args() + config_loader = config.Config.from_file(args.data_config) # Load graph data hierarchical, graph_ldict = utils.load_graph(args.graph) @@ -62,7 +63,7 @@ def main(): ) mesh_static_features = graph_ldict["mesh_static_features"] - grid_static_features = utils.load_static_data(args.dataset)[ + grid_static_features = utils.load_static_data(config_loader.dataset.name)[ "grid_static_features" ] diff --git a/train_model.py b/train_model.py index 96d21a3f..390da6d4 100644 --- a/train_model.py +++ b/train_model.py @@ -9,7 +9,7 @@ from lightning_fabric.utilities import seed # First-party -from neural_lam import constants, utils +from neural_lam import config, utils from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM from neural_lam.models.hi_lam_parallel import HiLAMParallel @@ -29,14 +29,11 @@ def main(): parser = ArgumentParser( description="Train or evaluate NeurWP models for LAM" ) - - # General options parser.add_argument( - "--dataset", + "--data_config", type=str, - default="meps_example", - help="Dataset, corresponding to name in data directory " - "(default: meps_example)", + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", ) parser.add_argument( "--model", @@ -183,8 +180,36 @@ def main(): help="Number of example predictions to plot during evaluation " "(default: 1)", ) + + # Logger Settings + parser.add_argument( + "--wandb_project", + type=str, + default="neural_lam", + help="Wandb project name (default: neural_lam)", + ) + parser.add_argument( + "--val_steps_to_log", + type=list, + default=[1, 2, 3, 5, 10, 15, 19], + help="Steps to log val loss for (default: [1, 2, 3, 5, 10, 15, 19])", + ) + parser.add_argument( + "--metrics_watch", + type=list, + default=[], + help="List of metrics to watch, including any prefix (e.g. 
val_rmse)", + ) + parser.add_argument( + "--var_leads_metrics_watch", + type=dict, + default={}, + help="Dict with variables and lead times to log watched metrics for", + ) args = parser.parse_args() + config_loader = config.Config.from_file(args.data_config) + # Asserts for arguments assert args.model in MODELS, f"Unknown model: {args.model}" assert args.step_length <= 3, "Too high step length" @@ -203,7 +228,7 @@ def main(): # Load data train_loader = torch.utils.data.DataLoader( WeatherDataset( - args.dataset, + config_loader.dataset.name, pred_length=args.ar_steps, split="train", subsample_step=args.step_length, @@ -217,7 +242,7 @@ def main(): max_pred_length = (65 // args.step_length) - 2 # 19 val_loader = torch.utils.data.DataLoader( WeatherDataset( - args.dataset, + config_loader.dataset.name, pred_length=max_pred_length, split="val", subsample_step=args.step_length, @@ -264,7 +289,7 @@ def main(): save_last=True, ) logger = pl.loggers.WandbLogger( - project=constants.WANDB_PROJECT, name=run_name, config=args + project=args.wandb_project, name=run_name, config=args ) trainer = pl.Trainer( max_epochs=args.epochs, @@ -280,7 +305,9 @@ def main(): # Only init once, on rank 0 only if trainer.global_rank == 0: - utils.init_wandb_metrics(logger) # Do after wandb.init + utils.init_wandb_metrics( + logger, args.val_steps_to_log + ) # Do after wandb.init if args.eval: if args.eval == "val": @@ -288,7 +315,7 @@ def main(): else: # Test eval_loader = torch.utils.data.DataLoader( WeatherDataset( - args.dataset, + config_loader.dataset.name, pred_length=max_pred_length, split="test", subsample_step=args.step_length, From 5b71be3c68d815e0e376ee651c14f09d801f86de Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 22 May 2024 13:36:22 +0200 Subject: [PATCH 3/5] Simplify pre-commit setup (#29) This PR simplifies the pre-commit setup by: - removing code checks for imports against external packages, these checks should be done when running ci/cd integration tests (which will be added in a separate PR) where external dependencies are installed. After this PR is merged, pre-commit will therefore only check if code is clean and self-consistent, without the need to install external dependencies (this is the same approach taken in for example https://github.com/pydata/xarray/blob/main/.pre-commit-config.yaml). Testing imports is _also_ important, but splitting the code-cleaning out allows for much faster iteration (i.e. these linting tests will fail, avoiding the need to install all dependencies) - pinning versions of used linting tools. The current setup is errorprone because as linting tools evolve the linting rules will update. Pinning versions ensures that we all use the same versions when running `pre-commit` - moves linting tool versions to pre-commit config rather than `requirements.txt`. This allows pre-commit handle the linting in its own virtual env, without polluting the dev environment - using github action to install and run pre-commit rather than our own run instructions. This ensure that pre-commit is run identically both during local development and during ci/cd in github actions. 
--------- Co-authored-by: sadamov <45732287+sadamov@users.noreply.github.com> Co-authored-by: khintz --- .flake8 | 3 ++ .github/workflows/pre-commit.yml | 36 +++++-------- .pre-commit-config.yaml | 66 ++++++++++-------------- CHANGELOG.md | 11 ++++ neural_lam/models/base_hi_graph_model.py | 14 ++--- plot_graph.py | 6 +-- requirements.txt | 6 +-- 7 files changed, 63 insertions(+), 79 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..b02dd545 --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 88 +ignore = E203, F811, I002, W503 diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index a6ad84f1..dc519e5b 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,33 +1,25 @@ -name: Run pre-commit job +name: lint on: - push: + # trigger on pushes to any branch, but not main + push: + branches-ignore: + - main + # and also on PRs to main + pull_request: branches: - - main - pull_request: - branches: - - main + - main jobs: - pre-commit-job: + pre-commit-job: runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.9 - - name: Install pre-commit hooks - run: | - pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 \ - --index-url https://download.pytorch.org/whl/cpu - pip install -r requirements.txt - pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 \ - torch-cluster==1.6.1 torch-geometric==2.3.1 \ - -f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html - - name: Run pre-commit hooks - run: | - pre-commit run --all-files + python-version: ${{ matrix.python-version }} + - uses: pre-commit/action@v2.0.3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f48eca67..815a92e1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,51 +1,37 @@ repos: -- repo: https://github.com/pre-commit/pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - - id: check-ast - - id: check-case-conflict - - id: check-docstring-first - - id: check-symlinks - - id: check-toml - - id: check-yaml - - id: debug-statements - - id: end-of-file-fixer - - id: trailing-whitespace -- repo: local + - id: check-ast + - id: check-case-conflict + - id: check-docstring-first + - id: check-symlinks + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/codespell-project/codespell + rev: v2.2.6 hooks: - - id: codespell - name: codespell + - id: codespell description: Check for spelling errors - language: system - entry: codespell -- repo: local + + - repo: https://github.com/psf/black + rev: 22.3.0 hooks: - - id: black - name: black + - id: black description: Format Python code - language: system - entry: black - types_or: [python, pyi] -- repo: local + + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 hooks: - - id: isort - name: isort + - id: isort description: Group and sort Python imports - language: system - entry: isort - types_or: [python, pyi, cython] -- repo: local + + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 hooks: - - id: flake8 - name: flake8 + - id: flake8 description: Check Python code for correctness, consistency and adherence to best practices - language: system - entry: flake8 
--max-line-length=80 --ignore=E203,F811,I002,W503 - types: [python] -- repo: local - hooks: - - id: pylint - name: pylint - entry: pylint -rn -sn - language: system - types: [python] diff --git a/CHANGELOG.md b/CHANGELOG.md index 823ac8b1..63feff96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) @joeloskarsson +### Maintenance + +- simplify pre-commit setup by 1) reducing linting to only cover static + analysis excluding imports from external dependencies (this will be handled + in build/test cicd action introduced later), 2) pinning versions of linting + tools in pre-commit config (and remove from `requirements.txt`) and 3) using + github action to run pre-commit. + [\#29](https://github.com/mllam/neural-lam/pull/29) + @leifdenby + + ## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0) First tagged release of `neural-lam`, matching Oskarsson et al 2023 publication diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py index 8ce87030..3fd30579 100644 --- a/neural_lam/models/base_hi_graph_model.py +++ b/neural_lam/models/base_hi_graph_model.py @@ -36,7 +36,7 @@ def __init__(self, args): if level_index < (self.num_levels - 1): up_edges = self.mesh_up_features[level_index].shape[0] down_edges = self.mesh_down_features[level_index].shape[0] - print(f" {level_index}<->{level_index+1}") + print(f" {level_index}<->{level_index + 1}") print(f" - {up_edges} up edges, {down_edges} down edges") # Embedders # Assume all levels have same static feature dimensionality @@ -179,9 +179,9 @@ def process_step(self, mesh_rep): ) # Update node and edge vectors in lists - mesh_rep_levels[level_l] = ( - new_node_rep # (B, num_mesh_nodes[l], d_h) - ) + mesh_rep_levels[ + level_l + ] = new_node_rep # (B, num_mesh_nodes[l], d_h) mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h) # - PROCESSOR - @@ -207,9 +207,9 @@ def process_step(self, mesh_rep): new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep) # Update node and edge vectors in lists - mesh_rep_levels[level_l] = ( - new_node_rep # (B, num_mesh_nodes[l], d_h) - ) + mesh_rep_levels[ + level_l + ] = new_node_rep # (B, num_mesh_nodes[l], d_h) # Return only bottom level representation return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h) diff --git a/plot_graph.py b/plot_graph.py index 40b2b41d..90462194 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -48,11 +48,7 @@ def main(): # Load graph data hierarchical, graph_ldict = utils.load_graph(args.graph) - ( - g2m_edge_index, - m2g_edge_index, - m2m_edge_index, - ) = ( + (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = ( graph_ldict["g2m_edge_index"], graph_ldict["m2g_edge_index"], graph_ldict["m2m_edge_index"], diff --git a/requirements.txt b/requirements.txt index 5a2111b2..f381d54f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,10 +10,6 @@ Cartopy>=0.22.0 pyproj>=3.4.1 tueplots>=0.0.8 plotly>=5.15.0 + # for dev -codespell>=2.0.0 -black>=21.9b0 -isort>=5.9.3 -flake8>=4.0.1 -pylint>=3.0.3 pre-commit>=2.15.0 From 879cfec1b49d0255ed963f44b3a9f55d42c9920a Mon Sep 17 00:00:00 2001 From: sadamov <45732287+sadamov@users.noreply.github.com> Date: Wed, 29 May 2024 16:07:36 +0200 Subject: [PATCH 4/5] Make restoration of optimizer and scheduler more robust (#17) ## Summary This pull request introduces specific enhancements to the model loading and optimizer/scheduler 
state restoration functionalities, improving flexibility and compatibility with multi-GPU setups. ## Detailed Changes - **Enhanced Model Loading for Multi-GPU**: Modified the model loading logic to better support multi-GPU environments by ensuring that optimizer states are only loaded when necessary and appropriate. - **Checkpoint Adjustments**: Adjusted how learning rate schedulers are restored from checkpoints to ensure they align correctly with the current training state ## Impact These changes provide users with greater control over how training states are restored and improve the script's functionality in distributed training environments. ## Testing [x] Changes have been tested in both single and multi-GPU setups ## Notes Further integration testing with different types of training configurations is recommended to fully validate the new functionalities. --------- Co-authored-by: Simon Adamov --- CHANGELOG.md | 4 ++++ neural_lam/models/ar_model.py | 10 +++++----- neural_lam/vis.py | 4 ++-- train_model.py | 12 +++--------- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 63feff96..061aa6bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Robust restoration of optimizer and scheduler using `ckpt_path` + [\#17](https://github.com/mllam/neural-lam/pull/17) + @sadamov + - Updated scripts and modules to use `data_config.yaml` instead of `constants.py` [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) @sadamov diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 9cda9fc2..29b169d4 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -83,8 +83,8 @@ def __init__(self, args): if self.output_std: self.test_metrics["output_std"] = [] # Treat as metric - # For making restoring of optimizer state optional (slight hack) - self.opt_state = None + # For making restoring of optimizer state optional + self.restore_opt = args.restore_opt # For example plotting self.n_example_pred = args.n_example_pred @@ -97,9 +97,6 @@ def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) ) - if self.opt_state: - opt.load_state_dict(self.opt_state) - return opt @property @@ -597,3 +594,6 @@ def on_load_checkpoint(self, checkpoint): ) loaded_state_dict[new_key] = loaded_state_dict[old_key] del loaded_state_dict[old_key] + if not self.restore_opt: + opt = self.configure_optimizers() + checkpoint["optimizer_states"] = [opt.state_dict()] diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 2b6abf15..8c9ca77c 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -87,7 +87,7 @@ def plot_prediction( 1, 2, figsize=(13, 7), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) # Plot pred and target @@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): fig, ax = plt.subplots( figsize=(5, 4.8), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) ax.coastlines() # Add coastline outlines diff --git a/train_model.py b/train_model.py index 390da6d4..fe064384 100644 --- a/train_model.py +++ b/train_model.py @@ -265,14 +265,7 @@ def main(): # Load model parameters Use new args for model model_class = MODELS[args.model] - if args.load: - model = 
model_class.load_from_checkpoint(args.load, args=args) - if args.restore_opt: - # Save for later - # Unclear if this works for multi-GPU - model.opt_state = torch.load(args.load)["optimizer_states"][0] - else: - model = model_class(args) + model = model_class(args) prefix = "subset-" if args.subset_ds else "" if args.eval: @@ -327,13 +320,14 @@ def main(): ) print(f"Running evaluation on {args.eval}") - trainer.test(model=model, dataloaders=eval_loader) + trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load) else: # Train model trainer.fit( model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, + ckpt_path=args.load, ) From 9d558d1f0d343cfe6e0babaa8d9e6c45b852fe21 Mon Sep 17 00:00:00 2001 From: sadamov <45732287+sadamov@users.noreply.github.com> Date: Fri, 31 May 2024 12:12:58 +0200 Subject: [PATCH 5/5] Fix minor bugs in data_config.yaml workflow (#40) ### Summary https://github.com/mllam/neural-lam/pull/31 introduced three minor bugs that are fixed with this PR: - r"" strings are not required in units of `data_config.yaml` - dictionaries cannot be passed as argsparse, rather JSON strings. This bug is related to the flag `var_leads_metrics_watch` --------- Co-authored-by: joeloskarsson --- neural_lam/data_config.yaml | 10 +++++----- neural_lam/models/ar_model.py | 2 +- train_model.py | 13 +++++++++---- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index f16a4a30..f1527849 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -21,8 +21,8 @@ dataset: var_units: - Pa - Pa - - r"$\mathrm{W}/\mathrm{m}^2$" - - r"$\mathrm{W}/\mathrm{m}^2$" + - $\mathrm{W}/\mathrm{m}^2$ + - $\mathrm{W}/\mathrm{m}^2$ - "" - "" - K @@ -33,9 +33,9 @@ dataset: - m/s - m/s - m/s - - r"$\mathrm{kg}/\mathrm{m}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" + - $\mathrm{kg}/\mathrm{m}^2$ + - $\mathrm{m}^2/\mathrm{s}^2$ + - $\mathrm{m}^2/\mathrm{s}^2$ var_longnames: - pres_heightAboveGround_0_instant - pres_heightAboveSea_0_instant diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 29b169d4..6ced211f 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -473,7 +473,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): # Check if metrics are watched, log exact values for specific vars if full_log_name in self.args.metrics_watch: for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var = self.config_loader.dataset.var_nums[var_i] + var = self.config_loader.dataset.var_names[var_i] log_dict.update( { f"{full_log_name}_{var}_step_{step}": metric_tensor[ diff --git a/train_model.py b/train_model.py index fe064384..cbd787f0 100644 --- a/train_model.py +++ b/train_model.py @@ -1,4 +1,5 @@ # Standard library +import json import random import time from argparse import ArgumentParser @@ -196,17 +197,21 @@ def main(): ) parser.add_argument( "--metrics_watch", - type=list, + nargs="+", default=[], help="List of metrics to watch, including any prefix (e.g. val_rmse)", ) parser.add_argument( "--var_leads_metrics_watch", - type=dict, - default={}, - help="Dict with variables and lead times to log watched metrics for", + type=str, + default="{}", + help="""JSON string with variable-IDs and lead times to log watched + metrics (e.g. 
'{"1": [1, 2], "3": [3, 4]}')""", ) args = parser.parse_args() + args.var_leads_metrics_watch = { + int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() + } config_loader = config.Config.from_file(args.data_config)
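
The two sketches below are editorial illustrations of how the configuration and CLI changes introduced in these patches are meant to be used; they are not part of any patch, and the paths and example values are taken from (or assumed to match) the diffs above.

First, loading the new `data_config.yaml` through the `Config` class added in `neural_lam/config.py` (assumes an environment with `pyyaml` and `cartopy` installed and the example config shipped in the patch):

```python
# Minimal usage sketch, not part of the patches.
from neural_lam import config

config_loader = config.Config.from_file("neural_lam/data_config.yaml")

print(config_loader.dataset.name)      # "meps_example"
print(config_loader.num_data_vars())   # 17 state variables in the example config
print(config_loader.grid_shape_state)  # [268, 238]

# coords_projection is a functools.cached_property, so it is accessed without
# parentheses (the later patch updates the call sites in vis.py accordingly)
# and returns a cartopy LambertConformal instance built from the yaml kwargs.
proj = config_loader.coords_projection
```

Second, the JSON-string workaround for dict-valued command line options from the last patch. The variable indices 6 and 14 are illustrative only (they correspond to `t_2` and `wvint_0` in the old `constants.py`):

```python
# Self-contained sketch of the --var_leads_metrics_watch pattern.
import json
from argparse import ArgumentParser

parser = ArgumentParser(description="Passing a dict-valued option through argparse")
parser.add_argument(
    "--var_leads_metrics_watch",
    type=str,
    default="{}",
    help="JSON string with variable indices and lead times to watch",
)

# Equivalent to invoking:
#   python train_model.py --var_leads_metrics_watch '{"6": [2, 19], "14": [2, 19]}'
args = parser.parse_args(
    ["--var_leads_metrics_watch", '{"6": [2, 19], "14": [2, 19]}']
)

# argparse only delivers a string; decode it and cast the keys back to int,
# mirroring the conversion added to train_model.py.
var_leads_metrics_watch = {
    int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items()
}
print(var_leads_metrics_watch)  # {6: [2, 19], 14: [2, 19]}
```

Passing the mapping as a JSON string keeps nested structure available on a flat command line, which `type=dict` in argparse cannot provide.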