diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..b02dd545
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+ignore = E203, F811, I002, W503
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index a6ad84f1..dc519e5b 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -1,33 +1,25 @@
-name: Run pre-commit job
+name: lint

 on:
-  push:
+  # trigger on pushes to any branch, but not main
+  push:
+    branches-ignore:
+      - main
+  # and also on PRs to main
+  pull_request:
     branches:
-      - main
-  pull_request:
-    branches:
-      - main
+      - main

 jobs:
-  pre-commit-job:
+  pre-commit-job:
     runs-on: ubuntu-latest
-    defaults:
-      run:
-        shell: bash -l {0}
+    strategy:
+      matrix:
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
         uses: actions/setup-python@v2
         with:
-          python-version: 3.9
-      - name: Install pre-commit hooks
-        run: |
-          pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 \
-            --index-url https://download.pytorch.org/whl/cpu
-          pip install -r requirements.txt
-          pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 \
-            torch-cluster==1.6.1 torch-geometric==2.3.1 \
-            -f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html
-      - name: Run pre-commit hooks
-        run: |
-          pre-commit run --all-files
+          python-version: ${{ matrix.python-version }}
+      - uses: pre-commit/action@v2.0.3
diff --git a/.gitignore b/.gitignore
index 7bb826a2..c9d914c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ graphs
 *.sif
 sweeps
 test_*.sh
+.vscode

 ### Python ###
 # Byte-compiled / optimized / DLL files
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f48eca67..815a92e1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,51 +1,37 @@
 repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
+  - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.5.0
     hooks:
-  - id: check-ast
-  - id: check-case-conflict
-  - id: check-docstring-first
-  - id: check-symlinks
-  - id: check-toml
-  - id: check-yaml
-  - id: debug-statements
-  - id: end-of-file-fixer
-  - id: trailing-whitespace
-- repo: local
+      - id: check-ast
+      - id: check-case-conflict
+      - id: check-docstring-first
+      - id: check-symlinks
+      - id: check-toml
+      - id: check-yaml
+      - id: debug-statements
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.2.6
     hooks:
-  - id: codespell
-    name: codespell
+      - id: codespell
         description: Check for spelling errors
-    language: system
-    entry: codespell
-- repo: local
+
+  - repo: https://github.com/psf/black
+    rev: 22.3.0
     hooks:
-  - id: black
-    name: black
+      - id: black
         description: Format Python code
-    language: system
-    entry: black
-    types_or: [python, pyi]
-- repo: local
+
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.12.0
     hooks:
-  - id: isort
-    name: isort
+      - id: isort
         description: Group and sort Python imports
-    language: system
-    entry: isort
-    types_or: [python, pyi, cython]
-- repo: local
+
+  - repo: https://github.com/PyCQA/flake8
+    rev: 7.0.0
     hooks:
-  - id: flake8
-    name: flake8
+      - id: flake8
         description: Check Python code for correctness, consistency and adherence to best practices
-    language: system
-    entry: flake8 --max-line-length=80 --ignore=E203,F811,I002,W503
-    types: [python]
-- repo: local
-  hooks:
-  - id: pylint
-    name: pylint
-    entry: pylint -rn -sn
-    language: system
-    types: [python]
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 00000000..63feff96
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,72 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD)
+
+### Added
+
+- Replaced `constants.py` with `data_config.yaml` for data configuration management
+  [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
+  @sadamov
+
+- new metrics (`nll` and `crps_gauss`) and `metrics` submodule, stddev output option
+  [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a)
+  @joeloskarsson
+
+- ability to "watch" metrics and log them
+  [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a)
+  @joeloskarsson
+
+- pre-commit setup for linting and formatting
+  [\#6](https://github.com/joeloskarsson/neural-lam/pull/6), [\#8](https://github.com/joeloskarsson/neural-lam/pull/8)
+  @sadamov, @joeloskarsson
+
+### Changed
+
+- Updated scripts and modules to use `data_config.yaml` instead of `constants.py`
+  [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
+  @sadamov
+
+- Added new flags in `train_model.py` for configuration previously in `constants.py`
+  [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
+  @sadamov
+
+- moved batch-static features ("water cover") into forcing component returned by `WeatherDataset`
+  [\#13](https://github.com/joeloskarsson/neural-lam/pull/13)
+  @joeloskarsson
+
+- change validation metric from `mae` to `rmse`
+  [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a)
+  @joeloskarsson
+
+- change RMSE definition to compute sqrt after all averaging
+  [\#10](https://github.com/joeloskarsson/neural-lam/pull/10)
+  @joeloskarsson
+
+### Removed
+
+- `WeatherDataset(torch.Dataset)` no longer returns "batch-static" component of
+  training item (only `prev_state`, `target_state` and `forcing`), the batch-static features are
+  instead included in forcing
+  [\#13](https://github.com/joeloskarsson/neural-lam/pull/13)
+  @joeloskarsson
+
+### Maintenance
+
+- simplify pre-commit setup by 1) reducing linting to only cover static
+  analysis excluding imports from external dependencies (this will be handled
+  in a build/test CI/CD action introduced later), 2) pinning versions of linting
+  tools in the pre-commit config (and removing them from `requirements.txt`) and 3) using a
+  GitHub Action to run pre-commit.
+  [\#29](https://github.com/mllam/neural-lam/pull/29)
+  @leifdenby
+
+
+## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0)
+
+First tagged release of `neural-lam`, matching the Oskarsson et al. 2023 publication
+(<https://arxiv.org/abs/2309.17370>)
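A note on the RMSE entry above: computing the square root after all averaging gives a genuinely different number than averaging per-step RMSEs. A toy comparison (illustrative only, not code from this diff):

```python
import torch

# Per-lead-time squared errors, already averaged over grid and batch
se = torch.tensor([1.0, 4.0, 9.0, 16.0])

rmse_old = torch.sqrt(se).mean()  # sqrt before averaging: (1 + 2 + 3 + 4) / 4 = 2.50
rmse_new = torch.sqrt(se.mean())  # sqrt after all averaging: sqrt(30 / 4) is about 2.74
```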
diff --git a/README.md b/README.md
index 67d9d9b1..ba0bb3fe 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,7 @@ Still, some restrictions are inevitable:

 ## A note on the limited area setting
 Currently we are using these models on a limited area covering the Nordic region, the so-called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
 There are still some parts of the code that are quite specific to the MEPS area use case.
-This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/constants.py`).
+This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified via `train_model.py --data_config`).
 If there is interest in using Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
 We would be happy to support such enhancements.
 See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
@@ -104,13 +104,12 @@ The graph-related files are stored in a directory called `graphs`.

 ### Create remaining static features
 To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`.
-The main option to set for these is just which dataset to use.

 ## Weights & Biases Integration
 The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it.
 When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface.
 If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`.
-The W&B project name is set to `neural-lam`, but this can be changed in `neural_lam/constants.py`.
+The W&B project name is set to `neural-lam`, but this can be changed with the `--wandb_project` flag of `train_model.py` (parsed with argparse).
 See the [W&B documentation](https://docs.wandb.ai/) for details.

 If you would like to login and use W&B, run:
os.path.join("data", args.dataset, "static") + config_loader = config.Config.from_file(args.data_config) + static_dir_path = os.path.join("data", config_loader.dataset.name, "static") graph_dir_path = os.path.join("graphs", args.graph) os.makedirs(graph_dir_path, exist_ok=True) diff --git a/create_parameter_weights.py b/create_parameter_weights.py index 494a5e81..cae1ae3e 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -8,7 +8,7 @@ from tqdm import tqdm # First-party -from neural_lam import constants +from neural_lam import config from neural_lam.weather_dataset import WeatherDataset @@ -18,10 +18,10 @@ def main(): """ parser = ArgumentParser(description="Training arguments") parser.add_argument( - "--dataset", + "--data_config", type=str, - default="meps_example", - help="Dataset to compute weights for (default: meps_example)", + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", ) parser.add_argument( "--batch_size", @@ -43,7 +43,8 @@ def main(): ) args = parser.parse_args() - static_dir_path = os.path.join("data", args.dataset, "static") + config_loader = config.Config.from_file(args.data_config) + static_dir_path = os.path.join("data", config_loader.dataset.name, "static") # Create parameter weights based on height # based on fig A.1 in graph cast paper @@ -56,7 +57,10 @@ def main(): "500": 0.03, } w_list = np.array( - [w_dict[par.split("_")[-2]] for par in constants.PARAM_NAMES] + [ + w_dict[par.split("_")[-2]] + for par in config_loader.dataset.var_longnames + ] ) print("Saving parameter weights...") np.save( @@ -66,7 +70,7 @@ def main(): # Load dataset without any subsampling ds = WeatherDataset( - args.dataset, + config_loader.dataset.name, split="train", subsample_step=1, pred_length=63, @@ -113,7 +117,7 @@ def main(): # Compute mean and std.-dev. of one-step differences across the dataset print("Computing mean and std.-dev. for one-step differences...") ds_standard = WeatherDataset( - args.dataset, + config_loader.dataset.name, split="train", subsample_step=1, pred_length=63, diff --git a/neural_lam/config.py b/neural_lam/config.py new file mode 100644 index 00000000..5891ea74 --- /dev/null +++ b/neural_lam/config.py @@ -0,0 +1,62 @@ +# Standard library +import functools +from pathlib import Path + +# Third-party +import cartopy.crs as ccrs +import yaml + + +class Config: + """ + Class for loading configuration files. + + This class loads a configuration file and provides a way to access its + values as attributes. 
+ """ + + def __init__(self, values): + self.values = values + + @classmethod + def from_file(cls, filepath): + """Load a configuration file.""" + if filepath.endswith(".yaml"): + with open(filepath, encoding="utf-8", mode="r") as file: + return cls(values=yaml.safe_load(file)) + else: + raise NotImplementedError(Path(filepath).suffix) + + def __getattr__(self, name): + keys = name.split(".") + value = self.values + for key in keys: + if key in value: + value = value[key] + else: + return None + if isinstance(value, dict): + return Config(values=value) + return value + + def __getitem__(self, key): + value = self.values[key] + if isinstance(value, dict): + return Config(values=value) + return value + + def __contains__(self, key): + return key in self.values + + def num_data_vars(self): + """Return the number of data variables for a given key.""" + return len(self.dataset.var_names) + + @functools.cached_property + def coords_projection(self): + """Return the projection.""" + proj_config = self.values["projection"] + proj_class_name = proj_config["class"] + proj_class = getattr(ccrs, proj_class_name) + proj_params = proj_config.get("kwargs", {}) + return proj_class(**proj_params) diff --git a/neural_lam/constants.py b/neural_lam/constants.py deleted file mode 100644 index 527c31d8..00000000 --- a/neural_lam/constants.py +++ /dev/null @@ -1,120 +0,0 @@ -# Third-party -import cartopy -import numpy as np - -WANDB_PROJECT = "neural-lam" - -SECONDS_IN_YEAR = ( - 365 * 24 * 60 * 60 -) # Assuming no leap years in dataset (2024 is next) - -# Log prediction error for these lead times -VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 5, 10, 15, 19]) - -# Log these metrics to wandb as scalar values for -# specific variables and lead times -# List of metrics to watch, including any prefix (e.g. val_rmse) -METRICS_WATCH = [] -# Dict with variables and lead times to log watched metrics for -# Format is a dictionary that maps from a variable index to -# a list of lead time steps -VAR_LEADS_METRICS_WATCH = { - 6: [2, 19], # t_2 - 14: [2, 19], # wvint_0 - 15: [2, 19], # z_1000 -} - -# Variable names -PARAM_NAMES = [ - "pres_heightAboveGround_0_instant", - "pres_heightAboveSea_0_instant", - "nlwrs_heightAboveGround_0_accum", - "nswrs_heightAboveGround_0_accum", - "r_heightAboveGround_2_instant", - "r_hybrid_65_instant", - "t_heightAboveGround_2_instant", - "t_hybrid_65_instant", - "t_isobaricInhPa_500_instant", - "t_isobaricInhPa_850_instant", - "u_hybrid_65_instant", - "u_isobaricInhPa_850_instant", - "v_hybrid_65_instant", - "v_isobaricInhPa_850_instant", - "wvint_entireAtmosphere_0_instant", - "z_isobaricInhPa_1000_instant", - "z_isobaricInhPa_500_instant", -] - -PARAM_NAMES_SHORT = [ - "pres_0g", - "pres_0s", - "nlwrs_0", - "nswrs_0", - "r_2", - "r_65", - "t_2", - "t_65", - "t_500", - "t_850", - "u_65", - "u_850", - "v_65", - "v_850", - "wvint_0", - "z_1000", - "z_500", -] -PARAM_UNITS = [ - "Pa", - "Pa", - "W/m\\textsuperscript{2}", - "W/m\\textsuperscript{2}", - "-", # unitless - "-", - "K", - "K", - "K", - "K", - "m/s", - "m/s", - "m/s", - "m/s", - "kg/m\\textsuperscript{2}", - "m\\textsuperscript{2}/s\\textsuperscript{2}", - "m\\textsuperscript{2}/s\\textsuperscript{2}", -] - -# Projection and grid -# Hard coded for now, but should eventually be part of dataset desc. 
files -GRID_SHAPE = (268, 238) # (y, x) - -LAMBERT_PROJ_PARAMS = { - "a": 6367470, - "b": 6367470, - "lat_0": 63.3, - "lat_1": 63.3, - "lat_2": 63.3, - "lon_0": 15.0, - "proj": "lcc", -} - -GRID_LIMITS = [ # In projection - -1059506.5523409774, # min x - 1310493.4476590226, # max x - -1331732.4471934352, # min y - 1338267.5528065648, # max y -] - -# Create projection -LAMBERT_PROJ = cartopy.crs.LambertConformal( - central_longitude=LAMBERT_PROJ_PARAMS["lon_0"], - central_latitude=LAMBERT_PROJ_PARAMS["lat_0"], - standard_parallels=( - LAMBERT_PROJ_PARAMS["lat_1"], - LAMBERT_PROJ_PARAMS["lat_2"], - ), -) - -# Data dimensions -GRID_FORCING_DIM = 5 * 3 + 1 # 5 feat. for 3 time-step window + 1 batch-static -GRID_STATE_DIM = 17 diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml new file mode 100644 index 00000000..f16a4a30 --- /dev/null +++ b/neural_lam/data_config.yaml @@ -0,0 +1,64 @@ +dataset: + name: meps_example + var_names: + - pres_0g + - pres_0s + - nlwrs_0 + - nswrs_0 + - r_2 + - r_65 + - t_2 + - t_65 + - t_500 + - t_850 + - u_65 + - u_850 + - v_65 + - v_850 + - wvint_0 + - z_1000 + - z_500 + var_units: + - Pa + - Pa + - r"$\mathrm{W}/\mathrm{m}^2$" + - r"$\mathrm{W}/\mathrm{m}^2$" + - "" + - "" + - K + - K + - K + - K + - m/s + - m/s + - m/s + - m/s + - r"$\mathrm{kg}/\mathrm{m}^2$" + - r"$\mathrm{m}^2/\mathrm{s}^2$" + - r"$\mathrm{m}^2/\mathrm{s}^2$" + var_longnames: + - pres_heightAboveGround_0_instant + - pres_heightAboveSea_0_instant + - nlwrs_heightAboveGround_0_accum + - nswrs_heightAboveGround_0_accum + - r_heightAboveGround_2_instant + - r_hybrid_65_instant + - t_heightAboveGround_2_instant + - t_hybrid_65_instant + - t_isobaricInhPa_500_instant + - t_isobaricInhPa_850_instant + - u_hybrid_65_instant + - u_isobaricInhPa_850_instant + - v_hybrid_65_instant + - v_isobaricInhPa_850_instant + - wvint_entireAtmosphere_0_instant + - z_isobaricInhPa_1000_instant + - z_isobaricInhPa_500_instant + num_forcing_features: 16 +grid_shape_state: [268, 238] +projection: + class: LambertConformal + kwargs: + central_longitude: 15.0 + central_latitude: 63.3 + standard_parallels: [63.3, 63.3] diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 6d160526..9cda9fc2 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -9,7 +9,7 @@ import wandb # First-party -from neural_lam import constants, metrics, utils, vis +from neural_lam import config, metrics, utils, vis class ARModel(pl.LightningModule): @@ -24,10 +24,13 @@ class ARModel(pl.LightningModule): def __init__(self, args): super().__init__() self.save_hyperparameters() - self.lr = args.lr + self.args = args + self.config_loader = config.Config.from_file(args.data_config) # Load static features for grid/data - static_data_dict = utils.load_static_data(args.dataset) + static_data_dict = utils.load_static_data( + self.config_loader.dataset.name + ) for static_data_name, static_data_tensor in static_data_dict.items(): self.register_buffer( static_data_name, static_data_tensor, persistent=False @@ -36,14 +39,11 @@ def __init__(self, args): # Double grid output dim. to also output std.-dev. self.output_std = bool(args.output_std) if self.output_std: - self.grid_output_dim = ( - 2 * constants.GRID_STATE_DIM - ) # Pred. dim. in grid cell + # Pred. dim. in grid cell + self.grid_output_dim = 2 * self.config_loader.num_data_vars() else: - self.grid_output_dim = ( - constants.GRID_STATE_DIM - ) # Pred. dim. in grid cell - + # Pred. dim. 
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 6d160526..9cda9fc2 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -9,7 +9,7 @@
 import wandb

 # First-party
-from neural_lam import constants, metrics, utils, vis
+from neural_lam import config, metrics, utils, vis


 class ARModel(pl.LightningModule):
@@ -24,10 +24,13 @@ class ARModel(pl.LightningModule):
     def __init__(self, args):
         super().__init__()
         self.save_hyperparameters()
-        self.lr = args.lr
+        self.args = args
+        self.config_loader = config.Config.from_file(args.data_config)

         # Load static features for grid/data
-        static_data_dict = utils.load_static_data(args.dataset)
+        static_data_dict = utils.load_static_data(
+            self.config_loader.dataset.name
+        )
         for static_data_name, static_data_tensor in static_data_dict.items():
             self.register_buffer(
                 static_data_name, static_data_tensor, persistent=False
@@ -36,14 +39,11 @@ def __init__(self, args):
         # Double grid output dim. to also output std.-dev.
         self.output_std = bool(args.output_std)
         if self.output_std:
-            self.grid_output_dim = (
-                2 * constants.GRID_STATE_DIM
-            )  # Pred. dim. in grid cell
+            # Pred. dim. in grid cell
+            self.grid_output_dim = 2 * self.config_loader.num_data_vars()
         else:
-            self.grid_output_dim = (
-                constants.GRID_STATE_DIM
-            )  # Pred. dim. in grid cell
-
+            # Pred. dim. in grid cell
+            self.grid_output_dim = self.config_loader.num_data_vars()
         # Store constant per-variable std.-dev. weighting
         # Note that this is the inverse of the multiplicative weighting
         # in wMSE/wMAE
@@ -57,11 +57,11 @@ def __init__(self, args):
         (
             self.num_grid_nodes,
             grid_static_dim,
-        ) = self.grid_static_features.shape  # 63784 = 268x238
+        ) = self.grid_static_features.shape
         self.grid_dim = (
-            2 * constants.GRID_STATE_DIM
+            2 * self.config_loader.num_data_vars()
             + grid_static_dim
-            + constants.GRID_FORCING_DIM
+            + self.config_loader.dataset.num_forcing_features
         )

         # Instantiate loss function
@@ -95,7 +95,7 @@ def __init__(self, args):
     def configure_optimizers(self):
         opt = torch.optim.AdamW(
-            self.parameters(), lr=self.lr, betas=(0.9, 0.95)
+            self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
         )
         if self.opt_state:
             opt.load_state_dict(self.opt_state)
@@ -246,7 +246,7 @@ def validation_step(self, batch, batch_idx):
         # Log loss per time step forward and mean
         val_log_dict = {
             f"val_loss_unroll{step}": time_step_loss[step - 1]
-            for step in constants.VAL_STEP_LOG_ERRORS
+            for step in self.args.val_steps_to_log
         }
         val_log_dict["val_mean_loss"] = mean_loss
         self.log_dict(
@@ -294,7 +294,7 @@ def test_step(self, batch, batch_idx):
         # Log loss per time step forward and mean
         test_log_dict = {
             f"test_loss_unroll{step}": time_step_loss[step - 1]
-            for step in constants.VAL_STEP_LOG_ERRORS
+            for step in self.args.val_steps_to_log
         }
         test_log_dict["test_mean_loss"] = mean_loss

@@ -328,7 +328,9 @@ def test_step(self, batch, batch_idx):
         spatial_loss = self.loss(
             prediction, target, pred_std, average_grid=False
         )  # (B, pred_steps, num_grid_nodes)
-        log_spatial_losses = spatial_loss[:, constants.VAL_STEP_LOG_ERRORS - 1]
+        log_spatial_losses = spatial_loss[
+            :, [step - 1 for step in self.args.val_steps_to_log]
+        ]
         self.spatial_loss_maps.append(log_spatial_losses)
         # (B, N_log, num_grid_nodes)

@@ -399,14 +401,15 @@ def plot_examples(self, batch, n_examples, prediction=None):
                     pred_t[:, var_i],
                     target_t[:, var_i],
                     self.interior_mask[:, 0],
+                    self.config_loader,
                     title=f"{var_name} ({var_unit}), "
                     f"t={t_i} ({self.step_length * t_i} h)",
                     vrange=var_vrange,
                 )
                 for var_i, (var_name, var_unit, var_vrange) in enumerate(
                     zip(
-                        constants.PARAM_NAMES_SHORT,
-                        constants.PARAM_UNITS,
+                        self.config_loader.dataset.var_names,
+                        self.config_loader.dataset.var_units,
                         var_vranges,
                     )
                 )
@@ -417,7 +420,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
                 {
                     f"{var_name}_example_{example_i}": wandb.Image(fig)
                     for var_name, fig in zip(
-                        constants.PARAM_NAMES_SHORT, var_figs
+                        self.config_loader.dataset.var_names, var_figs
                     )
                 }
             )
@@ -453,7 +456,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
         """
         log_dict = {}
         metric_fig = vis.plot_error_map(
-            metric_tensor, step_length=self.step_length
+            metric_tensor, self.config_loader, step_length=self.step_length
        )
         full_log_name = f"{prefix}_{metric_name}"
         log_dict[full_log_name] = wandb.Image(metric_fig)
@@ -471,14 +474,14 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
         )

         # Check if metrics are watched, log exact values for specific vars
-        if full_log_name in constants.METRICS_WATCH:
-            for var_i, timesteps in constants.VAR_LEADS_METRICS_WATCH.items():
-                var = constants.PARAM_NAMES_SHORT[var_i]
+        if full_log_name in self.args.metrics_watch:
+            for var_i, timesteps in self.args.var_leads_metrics_watch.items():
+                var = self.config_loader.dataset.var_names[var_i]
                 log_dict.update(
                     {
                         f"{full_log_name}_{var}_step_{step}": metric_tensor[
                             step - 1, var_i
-                        ]  # 1-indexed in constants
+                        ]  # 1-indexed in data_config
                         for step in timesteps
                     }
                 )
@@ -542,10 +545,11 @@ def on_test_epoch_end(self):
                 vis.plot_spatial_error(
                     loss_map,
                     self.interior_mask[:, 0],
+                    self.config_loader,
                     title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
                 )
                 for t_i, loss_map in zip(
-                    constants.VAL_STEP_LOG_ERRORS, mean_spatial_loss
+                    self.args.val_steps_to_log, mean_spatial_loss
                 )
             ]

@@ -555,14 +559,14 @@ def on_test_epoch_end(self):

             # also make without title and save as pdf
             pdf_loss_map_figs = [
-                vis.plot_spatial_error(loss_map, self.interior_mask[:, 0])
+                vis.plot_spatial_error(
+                    loss_map, self.interior_mask[:, 0], self.config_loader
+                )
                 for loss_map in mean_spatial_loss
             ]
             pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
             os.makedirs(pdf_loss_maps_dir, exist_ok=True)
-            for t_i, fig in zip(
-                constants.VAL_STEP_LOG_ERRORS, pdf_loss_map_figs
-            ):
+            for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
                 fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
             # save mean spatial loss as .pt file also
             torch.save(
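One detail worth calling out in the hunks above: `val_steps_to_log` holds 1-indexed lead times, so selecting them from a 0-indexed tensor needs the `step - 1` shift. A toy illustration (not code from this diff):

```python
import torch

val_steps_to_log = [1, 2, 3, 5, 10, 15, 19]  # 1-indexed lead times

# (batch, pred_steps, num_grid_nodes) with pred_steps = 19
spatial_loss = torch.rand(4, 19, 100)

log_spatial_losses = spatial_loss[
    :, [step - 1 for step in val_steps_to_log]
]  # keeps only the watched lead times
print(log_spatial_losses.shape)  # torch.Size([4, 7, 100])
```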
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index 9686d867..3cfdc009 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -182,9 +182,9 @@ def process_step(self, mesh_rep):
             )

             # Update node and edge vectors in lists
-            mesh_rep_levels[level_l] = (
-                new_node_rep  # (B, num_mesh_nodes[l], d_h)
-            )
+            mesh_rep_levels[
+                level_l
+            ] = new_node_rep  # (B, num_mesh_nodes[l], d_h)
             mesh_up_rep[level_l - 1] = new_edge_rep  # (B, M_up[l-1], d_h)

         # - PROCESSOR -
@@ -210,9 +210,9 @@ def process_step(self, mesh_rep):
             new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep)

             # Update node and edge vectors in lists
-            mesh_rep_levels[level_l] = (
-                new_node_rep  # (B, num_mesh_nodes[l], d_h)
-            )
+            mesh_rep_levels[
+                level_l
+            ] = new_node_rep  # (B, num_mesh_nodes[l], d_h)

         # Return only bottom level representation
         return mesh_rep_levels[0]  # (B, num_mesh_nodes[0], d_h)
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 943fc84e..d1602cfd 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -11,9 +11,6 @@
 from torch import nn
 from tueplots import bundles, figsizes

-# First-party
-from neural_lam import constants
-

 def load_dataset_stats(dataset_name, device="cpu"):
     """
@@ -267,7 +264,7 @@ def fractional_plot_bundle(fraction):
     return bundle


-def init_wandb_metrics(wandb_logger):
+def init_wandb_metrics(wandb_logger, val_steps):
     """
     Set up wandb metrics to track
     """
@@ -316,7 +313,7 @@ def init_wandb(args):
     )
     experiment = logger.experiment
     experiment.define_metric("val_mean_loss", summary="min")
-    for step in constants.VAL_STEP_LOG_ERRORS:
+    for step in val_steps:
         experiment.define_metric(f"val_loss_unroll{step}", summary="min")
     return logger
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index cef34a84..2b6abf15 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -4,11 +4,11 @@
 import numpy as np

 # First-party
-from neural_lam import constants, utils
+from neural_lam import utils


 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_error_map(errors, title=None, step_length=3):
+def plot_error_map(errors, data_config, title=None, step_length=3):
     """
     Plot a heatmap of errors of different variables at different
     prediction horizons
     """
@@ -51,7 +51,7 @@ def plot_error_map(errors, title=None, step_length=3):
     y_ticklabels = [
         f"{name} ({unit})"
         for name, unit in zip(
-            constants.PARAM_NAMES_SHORT, constants.PARAM_UNITS
+            data_config.dataset.var_names, data_config.dataset.var_units
         )
     ]
     ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
@@ -63,7 +63,9 @@ def plot_error_map(errors, title=None, step_length=3):

 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_prediction(pred, target, obs_mask, title=None, vrange=None):
+def plot_prediction(
+    pred, target, obs_mask, data_config, title=None, vrange=None
+):
     """
     Plot example prediction and ground truth.
     Each has shape (N_grid,)
@@ -76,23 +78,25 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None):
         vmin, vmax = vrange

     # Set up masking of border region
-    mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE)
+    mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state)
     pixel_alpha = (
         mask_reshaped.clamp(0.7, 1).cpu().numpy()
     )  # Faded border region

     fig, axes = plt.subplots(
-        1, 2, figsize=(13, 7), subplot_kw={"projection": constants.LAMBERT_PROJ}
+        1,
+        2,
+        figsize=(13, 7),
+        subplot_kw={"projection": data_config.coords_projection},
     )

     # Plot pred and target
     for ax, data in zip(axes, (target, pred)):
         ax.coastlines()  # Add coastline outlines
-        data_grid = data.reshape(*constants.GRID_SHAPE).cpu().numpy()
+        data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy()
         im = ax.imshow(
             data_grid,
             origin="lower",
-            extent=constants.GRID_LIMITS,
             alpha=pixel_alpha,
             vmin=vmin,
             vmax=vmax,
@@ -112,7 +116,7 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None):

 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_spatial_error(error, obs_mask, title=None, vrange=None):
+def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
     """
     Plot errors over spatial map
     Error and obs_mask have shape (N_grid,)
@@ -125,22 +129,22 @@ def plot_spatial_error(error, obs_mask, title=None, vrange=None):
         vmin, vmax = vrange

     # Set up masking of border region
-    mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE)
+    mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state)
     pixel_alpha = (
         mask_reshaped.clamp(0.7, 1).cpu().numpy()
     )  # Faded border region

     fig, ax = plt.subplots(
-        figsize=(5, 4.8), subplot_kw={"projection": constants.LAMBERT_PROJ}
+        figsize=(5, 4.8),
+        subplot_kw={"projection": data_config.coords_projection},
     )

     ax.coastlines()  # Add coastline outlines
-    error_grid = error.reshape(*constants.GRID_SHAPE).cpu().numpy()
+    error_grid = error.reshape(*data_config.grid_shape_state).cpu().numpy()
     im = ax.imshow(
         error_grid,
         origin="lower",
-        extent=constants.GRID_LIMITS,
         alpha=pixel_alpha,
         vmin=vmin,
         vmax=vmax,
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index eeefc313..a782806b 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -8,7 +8,7 @@
 import torch

 # First-party
-from neural_lam import constants, utils
+from neural_lam import utils


 class WeatherDataset(torch.utils.data.Dataset):
@@ -218,9 +218,11 @@ def __getitem__(self, idx):
         # can roll over to next year, ok because periodicity

         # Encode as sin/cos
+        # ! Make this more flexible in a separate create_forcings.py script
+        seconds_in_year = 365 * 24 * 3600
         hour_angle = (hour_of_day / 12) * torch.pi  # (sample_len,)
         year_angle = (
-            (second_into_year / constants.SECONDS_IN_YEAR) * 2 * torch.pi
+            (second_into_year / seconds_in_year) * 2 * torch.pi
         )  # (sample_len,)
         datetime_forcing = torch.stack(
             (
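For context, the datetime forcing computed above turns time-of-day and time-of-year into smooth periodic features; the `torch.stack` call the hunk ends on combines sin/cos pairs of the two angles. A standalone sketch with toy values (illustrative, not code from this diff):

```python
import torch

seconds_in_year = 365 * 24 * 3600  # leap years ignored, as in the dataset

hour_of_day = torch.tensor([0.0, 6.0, 12.0, 18.0])  # (sample_len,)
second_into_year = torch.tensor([0.0, 7884000.0, 15768000.0, 23652000.0])

hour_angle = (hour_of_day / 12) * torch.pi  # full turn every 24 h
year_angle = (second_into_year / seconds_in_year) * 2 * torch.pi

# Each angle contributes a (sin, cos) pair, so midnight and new year wrap smoothly
datetime_forcing = torch.stack(
    (
        torch.sin(hour_angle),
        torch.cos(hour_angle),
        torch.sin(year_angle),
        torch.cos(year_angle),
    ),
    dim=1,
)  # (sample_len, 4)
print(datetime_forcing.shape)  # torch.Size([4, 4])
```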
val_rmse)", + ) + parser.add_argument( + "--var_leads_metrics_watch", + type=dict, + default={}, + help="Dict with variables and lead times to log watched metrics for", + ) args = parser.parse_args() + config_loader = config.Config.from_file(args.data_config) + # Asserts for arguments assert args.model in MODELS, f"Unknown model: {args.model}" assert args.step_length <= 3, "Too high step length" @@ -206,7 +232,7 @@ def main(): # Load data train_loader = torch.utils.data.DataLoader( WeatherDataset( - args.dataset, + config_loader.dataset.name, pred_length=args.ar_steps, split="train", subsample_step=args.step_length, @@ -220,7 +246,7 @@ def main(): max_pred_length = (65 // args.step_length) - 2 # 19 val_loader = torch.utils.data.DataLoader( WeatherDataset( - args.dataset, + config_loader.dataset.name, pred_length=max_pred_length, split="val", subsample_step=args.step_length, @@ -287,7 +313,7 @@ def main(): else: # Test eval_loader = torch.utils.data.DataLoader( WeatherDataset( - args.dataset, + config_loader.dataset.name, pred_length=max_pred_length, split="test", subsample_step=args.step_length,