From 5f538f909854b61e332a5938905f686a72a88349 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 28 May 2024 15:43:10 +0200 Subject: [PATCH] cherry-pick with main --- .github/workflows/pre-commit.yml | 36 +++--- .pre-commit-config.yaml | 66 +++++------ CHANGELOG.md | 72 ++++++++++++ README.md | 5 +- create_mesh.py | 19 ++- neural_lam/config.py | 192 +++++++++++++++++++++++++++++++ neural_lam/models/ar_model.py | 5 +- neural_lam/utils.py | 189 ------------------------------ neural_lam/weather_dataset.py | 4 +- plot_graph.py | 18 +-- requirements.txt | 5 - 11 files changed, 329 insertions(+), 282 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 neural_lam/config.py diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index a6ad84f1..dc519e5b 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,33 +1,25 @@ -name: Run pre-commit job +name: lint on: - push: + # trigger on pushes to any branch, but not main + push: + branches-ignore: + - main + # and also on PRs to main + pull_request: branches: - - main - pull_request: - branches: - - main + - main jobs: - pre-commit-job: + pre-commit-job: runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.9 - - name: Install pre-commit hooks - run: | - pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 \ - --index-url https://download.pytorch.org/whl/cpu - pip install -r requirements.txt - pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 \ - torch-cluster==1.6.1 torch-geometric==2.3.1 \ - -f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html - - name: Run pre-commit hooks - run: | - pre-commit run --all-files + python-version: ${{ matrix.python-version }} + - uses: pre-commit/action@v2.0.3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f48eca67..815a92e1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,51 +1,37 @@ repos: -- repo: https://github.com/pre-commit/pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - - id: check-ast - - id: check-case-conflict - - id: check-docstring-first - - id: check-symlinks - - id: check-toml - - id: check-yaml - - id: debug-statements - - id: end-of-file-fixer - - id: trailing-whitespace -- repo: local + - id: check-ast + - id: check-case-conflict + - id: check-docstring-first + - id: check-symlinks + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/codespell-project/codespell + rev: v2.2.6 hooks: - - id: codespell - name: codespell + - id: codespell description: Check for spelling errors - language: system - entry: codespell -- repo: local + + - repo: https://github.com/psf/black + rev: 22.3.0 hooks: - - id: black - name: black + - id: black description: Format Python code - language: system - entry: black - types_or: [python, pyi] -- repo: local + + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 hooks: - - id: isort - name: isort + - id: isort description: Group and sort Python imports - language: system - entry: isort - types_or: [python, pyi, cython] -- repo: local + + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 hooks: - - id: flake8 - name: flake8 + - id: flake8 description: Check Python code for correctness, consistency and adherence to best practices - language: system - entry: flake8 --max-line-length=80 --ignore=E203,F811,I002,W503 - types: [python] -- repo: local - hooks: - - id: pylint - name: pylint - entry: pylint -rn -sn - language: system - types: [python] diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..63feff96 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,72 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD) + +### Added + +- Replaced `constants.py` with `data_config.yaml` for data configuration management + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + +- new metrics (`nll` and `crps_gauss`) and `metrics` submodule, stddiv output option + [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) + @joeloskarsson + +- ability to "watch" metrics and log + [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) + @joeloskarsson + +- pre-commit setup for linting and formatting + [\#6](https://github.com/joeloskarsson/neural-lam/pull/6), [\#8](https://github.com/joeloskarsson/neural-lam/pull/8) + @sadamov, @joeloskarsson + +### Changed + +- Updated scripts and modules to use `data_config.yaml` instead of `constants.py` + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + +- Added new flags in `train_model.py` for configuration previously in `constants.py` + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + +- moved batch-static features ("water cover") into forcing component return by `WeatherDataset` + [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) + @joeloskarsson + +- change validation metric from `mae` to `rmse` + [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) + @joeloskarsson + +- change RMSE definition to compute sqrt after all averaging + [\#10](https://github.com/joeloskarsson/neural-lam/pull/10) + @joeloskarsson + +### Removed + +- `WeatherDataset(torch.Dataset)` no longer returns "batch-static" component of + training item (only `prev_state`, `target_state` and `forcing`), the batch static features are + instead included in forcing + [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) + @joeloskarsson + +### Maintenance + +- simplify pre-commit setup by 1) reducing linting to only cover static + analysis excluding imports from external dependencies (this will be handled + in build/test cicd action introduced later), 2) pinning versions of linting + tools in pre-commit config (and remove from `requirements.txt`) and 3) using + github action to run pre-commit. + [\#29](https://github.com/mllam/neural-lam/pull/29) + @leifdenby + + +## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0) + +First tagged release of `neural-lam`, matching Oskarsson et al 2023 publication +() diff --git a/README.md b/README.md index 67d9d9b1..ba0bb3fe 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Still, some restrictions are inevitable: ## A note on the limited area setting Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)). There are still some parts of the code that is quite specific for the MEPS area use case. -This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/constants.py`). +This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified in `train_model.py --data_config` ). If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic. We would be happy to support such enhancements. See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done. @@ -104,13 +104,12 @@ The graph-related files are stored in a directory called `graphs`. ### Create remaining static features To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`. -The main option to set for these is just which dataset to use. ## Weights & Biases Integration The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it. When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface. If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`. -The W&B project name is set to `neural-lam`, but this can be changed in `neural_lam/constants.py`. +The W&B project name is set to `neural-lam`, but this can be changed in the flags of `train_model.py` (using argsparse). See the [W&B documentation](https://docs.wandb.ai/) for details. If you would like to login and use W&B, run: diff --git a/create_mesh.py b/create_mesh.py index 2b6af9fd..da881594 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -13,9 +13,7 @@ from torch_geometric.utils.convert import from_networkx # First-party -from neural_lam import utils - -# matplotlib.use('TkAgg') +from neural_lam import config def plot_graph(graph, title=None): @@ -157,6 +155,12 @@ def prepend_node_index(graph, new_index): def main(): parser = ArgumentParser(description="Graph generation arguments") + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", + ) parser.add_argument( "--graph", type=str, @@ -182,20 +186,13 @@ def main(): default=0, help="Generate hierarchical mesh graph (default: 0, no)", ) - parser.add_argument( - "--data_config", - type=str, - default="neural_lam/data_config.yaml", - help="Path to data config file (default: neural_lam/data_config.yaml)", - ) - args = parser.parse_args() # Load grid positions graph_dir_path = os.path.join("graphs", args.graph) os.makedirs(graph_dir_path, exist_ok=True) - config_loader = utils.ConfigLoader(args.data_config) + config_loader = config.Config(args.data_config) xy = config_loader.get_nwp_xy() grid_xy = torch.tensor(xy) pos_max = torch.max(torch.abs(grid_xy)) diff --git a/neural_lam/config.py b/neural_lam/config.py new file mode 100644 index 00000000..819ce2aa --- /dev/null +++ b/neural_lam/config.py @@ -0,0 +1,192 @@ + +import os + +import cartopy.crs as ccrs +import numpy as np +import xarray as xr +import yaml + + +class Config: + """ + Class for loading configuration files. + + This class loads a YAML configuration file and provides a way to access + its values as attributes. + """ + + def __init__(self, config_path, values=None): + self.config_path = config_path + if values is None: + self.values = self.load_config() + else: + self.values = values + + def load_config(self): + """Load configuration file.""" + with open(self.config_path, encoding="utf-8", mode="r") as file: + return yaml.safe_load(file) + + def __getattr__(self, name): + keys = name.split(".") + value = self.values + for key in keys: + if key in value: + value = value[key] + else: + return None + if isinstance(value, dict): + return Config(None, values=value) + return value + + def __getitem__(self, key): + value = self.values[key] + if isinstance(value, dict): + return Config(None, values=value) + return value + + def __contains__(self, key): + return key in self.values + + def param_names(self): + """Return parameter names.""" + surface_names = self.values["state"]["surface"] + atmosphere_names = [ + f"{var}_{level}" + for var in self.values["state"]["atmosphere"] + for level in self.values["state"]["levels"] + ] + return surface_names + atmosphere_names + + def param_units(self): + """Return parameter units.""" + surface_units = self.values["state"]["surface_units"] + atmosphere_units = [ + unit + for unit in self.values["state"]["atmosphere_units"] + for _ in self.values["state"]["levels"] + ] + return surface_units + atmosphere_units + + def num_data_vars(self, key): + """Return the number of data variables for a given key.""" + surface_vars = len(self.values[key]["surface"]) + atmosphere_vars = len(self.values[key]["atmosphere"]) + levels = len(self.values[key]["levels"]) + return surface_vars + atmosphere_vars * levels + + def projection(self): + """Return the projection.""" + proj_config = self.values["projections"]["class"] + proj_class = getattr(ccrs, proj_config["proj_class"]) + proj_params = proj_config["proj_params"] + return proj_class(**proj_params) + + def open_zarr(self, dataset_name): + """Open a dataset specified by the dataset name.""" + dataset_path = self.zarrs[dataset_name].path + if dataset_path is None or not os.path.exists(dataset_path): + print(f"Dataset '{dataset_name}' not found at path: {dataset_path}") + return None + dataset = xr.open_zarr(dataset_path, consolidated=True) + return dataset + + def load_normalization_stats(self): + """Load normalization statistics from Zarr archive.""" + normalization_path = self.normalization.zarr + if not os.path.exists(normalization_path): + print( + f"Normalization statistics not found at " + f"path: {normalization_path}" + ) + return None + normalization_stats = xr.open_zarr( + normalization_path, consolidated=True + ) + return normalization_stats + + def process_dataset(self, dataset_name, split="train", stack=True): + """Process a single dataset specified by the dataset name.""" + + dataset = self.open_zarr(dataset_name) + if dataset is None: + return None + + start, end = ( + self.splits[split].start, + self.splits[split].end, + ) + dataset = dataset.sel(time=slice(start, end)) + dataset = dataset.rename_dims( + { + v: k + for k, v in self.zarrs[dataset_name].dims.values.items() + if k not in dataset.dims + } + ) + + vars_surface = [] + if self[dataset_name].surface: + vars_surface = dataset[self[dataset_name].surface] + + vars_atmosphere = [] + if self[dataset_name].atmosphere: + vars_atmosphere = xr.merge( + [ + dataset[var] + .sel(level=level, drop=True) + .rename(f"{var}_{level}") + for var in self[dataset_name].atmosphere + for level in self[dataset_name].levels + ] + ) + + if vars_surface and vars_atmosphere: + dataset = xr.merge([vars_surface, vars_atmosphere]) + elif vars_surface: + dataset = vars_surface + elif vars_atmosphere: + dataset = vars_atmosphere + else: + print(f"No variables found in dataset {dataset_name}") + return None + + if not all( + lat_lon in self.zarrs[dataset_name].dims.values.values() + for lat_lon in self.zarrs[ + dataset_name + ].lat_lon_names.values.values() + ): + lat_name = self.zarrs[dataset_name].lat_lon_names.lat + lon_name = self.zarrs[dataset_name].lat_lon_names.lon + if dataset[lat_name].ndim == 2: + dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True) + if dataset[lon_name].ndim == 2: + dataset[lon_name] = dataset[lon_name].isel(y=0, drop=True) + dataset = dataset.assign_coords( + x=dataset[lon_name], y=dataset[lat_name] + ) + + if stack: + dataset = self.stack_grid(dataset) + + return dataset + + def stack_grid(self, dataset): + """Stack grid dimensions.""" + dataset = dataset.squeeze().stack(grid=("x", "y")).to_array() + + if "time" in dataset.dims: + dataset = dataset.transpose("time", "grid", "variable") + else: + dataset = dataset.transpose("grid", "variable") + return dataset + + def get_nwp_xy(self): + """Get the x and y coordinates for the NWP grid.""" + x = self.process_dataset("static", stack=False).x.values + y = self.process_dataset("static", stack=False).y.values + xx, yy = np.meshgrid(y, x) + xy = np.stack((xx, yy), axis=0) + + return xy diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index f49eb094..fff28632 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -6,10 +6,11 @@ import numpy as np import pytorch_lightning as pl import torch + import wandb # First-party -from neural_lam import metrics, utils, vis +from neural_lam import metrics, vis class ARModel(pl.LightningModule): @@ -25,7 +26,7 @@ def __init__(self, args): super().__init__() self.save_hyperparameters() self.args = args - self.config_loader = utils.ConfigLoader(args.data_config) + self.config_loader = config.Config(args.data_config) # Load static features for grid/data static = self.config_loader.process_dataset("static") diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 96e1549e..18584d2e 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -2,11 +2,7 @@ import os # Third-party -import cartopy.crs as ccrs -import numpy as np import torch -import xarray as xr -import yaml from torch import nn from tueplots import bundles, figsizes @@ -197,188 +193,3 @@ def init_wandb_metrics(wandb_logger, val_steps): experiment.define_metric("val_mean_loss", summary="min") for step in val_steps: experiment.define_metric(f"val_loss_unroll{step}", summary="min") - - -class ConfigLoader: - """ - Class for loading configuration files. - - This class loads a YAML configuration file and provides a way to access - its values as attributes. - """ - - def __init__(self, config_path, values=None): - self.config_path = config_path - if values is None: - self.values = self.load_config() - else: - self.values = values - - def load_config(self): - """Load configuration file.""" - with open(self.config_path, encoding="utf-8", mode="r") as file: - return yaml.safe_load(file) - - def __getattr__(self, name): - keys = name.split(".") - value = self.values - for key in keys: - if key in value: - value = value[key] - else: - return None - if isinstance(value, dict): - return ConfigLoader(None, values=value) - return value - - def __getitem__(self, key): - value = self.values[key] - if isinstance(value, dict): - return ConfigLoader(None, values=value) - return value - - def __contains__(self, key): - return key in self.values - - def param_names(self): - """Return parameter names.""" - surface_names = self.values["state"]["surface"] - atmosphere_names = [ - f"{var}_{level}" - for var in self.values["state"]["atmosphere"] - for level in self.values["state"]["levels"] - ] - return surface_names + atmosphere_names - - def param_units(self): - """Return parameter units.""" - surface_units = self.values["state"]["surface_units"] - atmosphere_units = [ - unit - for unit in self.values["state"]["atmosphere_units"] - for _ in self.values["state"]["levels"] - ] - return surface_units + atmosphere_units - - def num_data_vars(self, key): - """Return the number of data variables for a given key.""" - surface_vars = len(self.values[key]["surface"]) - atmosphere_vars = len(self.values[key]["atmosphere"]) - levels = len(self.values[key]["levels"]) - return surface_vars + atmosphere_vars * levels - - def projection(self): - """Return the projection.""" - proj_config = self.values["projections"]["class"] - proj_class = getattr(ccrs, proj_config["proj_class"]) - proj_params = proj_config["proj_params"] - return proj_class(**proj_params) - - def open_zarr(self, dataset_name): - """Open a dataset specified by the dataset name.""" - dataset_path = self.zarrs[dataset_name].path - if dataset_path is None or not os.path.exists(dataset_path): - print(f"Dataset '{dataset_name}' not found at path: {dataset_path}") - return None - dataset = xr.open_zarr(dataset_path, consolidated=True) - return dataset - - def load_normalization_stats(self): - """Load normalization statistics from Zarr archive.""" - normalization_path = self.normalization.zarr - if not os.path.exists(normalization_path): - print( - f"Normalization statistics not found at " - f"path: {normalization_path}" - ) - return None - normalization_stats = xr.open_zarr( - normalization_path, consolidated=True - ) - return normalization_stats - - def process_dataset(self, dataset_name, split="train", stack=True): - """Process a single dataset specified by the dataset name.""" - - dataset = self.open_zarr(dataset_name) - if dataset is None: - return None - - start, end = ( - self.splits[split].start, - self.splits[split].end, - ) - dataset = dataset.sel(time=slice(start, end)) - dataset = dataset.rename_dims( - { - v: k - for k, v in self.zarrs[dataset_name].dims.values.items() - if k not in dataset.dims - } - ) - - vars_surface = [] - if self[dataset_name].surface: - vars_surface = dataset[self[dataset_name].surface] - - vars_atmosphere = [] - if self[dataset_name].atmosphere: - vars_atmosphere = xr.merge( - [ - dataset[var] - .sel(level=level, drop=True) - .rename(f"{var}_{level}") - for var in self[dataset_name].atmosphere - for level in self[dataset_name].levels - ] - ) - - if vars_surface and vars_atmosphere: - dataset = xr.merge([vars_surface, vars_atmosphere]) - elif vars_surface: - dataset = vars_surface - elif vars_atmosphere: - dataset = vars_atmosphere - else: - print(f"No variables found in dataset {dataset_name}") - return None - - if not all( - lat_lon in self.zarrs[dataset_name].dims.values.values() - for lat_lon in self.zarrs[ - dataset_name - ].lat_lon_names.values.values() - ): - lat_name = self.zarrs[dataset_name].lat_lon_names.lat - lon_name = self.zarrs[dataset_name].lat_lon_names.lon - if dataset[lat_name].ndim == 2: - dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True) - if dataset[lon_name].ndim == 2: - dataset[lon_name] = dataset[lon_name].isel(y=0, drop=True) - dataset = dataset.assign_coords( - x=dataset[lon_name], y=dataset[lat_name] - ) - - if stack: - dataset = self.stack_grid(dataset) - - return dataset - - def stack_grid(self, dataset): - """Stack grid dimensions.""" - dataset = dataset.squeeze().stack(grid=("x", "y")).to_array() - - if "time" in dataset.dims: - dataset = dataset.transpose("time", "grid", "variable") - else: - dataset = dataset.transpose("grid", "variable") - return dataset - - def get_nwp_xy(self): - """Get the x and y coordinates for the NWP grid.""" - x = self.process_dataset("static", stack=False).x.values - y = self.process_dataset("static", stack=False).y.values - xx, yy = np.meshgrid(y, x) - xy = np.stack((xx, yy), axis=0) - - return xy diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 4b5da0a8..6ce630c7 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -3,7 +3,7 @@ import torch # First-party -from neural_lam import utils +from neural_lam import config class WeatherDataset(torch.utils.data.Dataset): @@ -35,7 +35,7 @@ def __init__( self.batch_size = batch_size self.ar_steps = ar_steps self.control_only = control_only - self.config_loader = utils.ConfigLoader(data_config) + self.config_loader = config.Config(data_config) self.state = self.config_loader.process_dataset("state", self.split) assert self.state is not None, "State dataset not found" diff --git a/plot_graph.py b/plot_graph.py index 50c54e06..9b465fd4 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -76,8 +76,16 @@ def export_to_3d_model(node_pos, edge_plot_list, filename): def main(): - """Plot the graph.""" + """ + Plot graph structure in 3D using plotly + """ parser = ArgumentParser(description="Plot graph") + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", + ) parser.add_argument( "--graph", type=str, @@ -95,12 +103,6 @@ def main(): default=0, help="If the axis should be displayed (default: 0 (No))", ) - parser.add_argument( - "--data_config", - type=str, - default="neural_lam/data_config.yaml", - help="Path to data config file (default: neural_lam/data_config.yaml)", - ) parser.add_argument( "--export", type=str, @@ -121,7 +123,7 @@ def main(): ) mesh_static_features = graph_ldict["mesh_static_features"] - config_loader = utils.ConfigLoader(args.data_config) + config_loader = config.Config(args.data_config) xy = config_loader.get_nwp_xy() grid_xy = xy.transpose(1, 2, 0).reshape(-1, 2) pos_max = np.max(np.abs(grid_xy)) diff --git a/requirements.txt b/requirements.txt index cb9bd425..70b97330 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,9 +14,4 @@ xarray>=0.20.1 zarr>=2.10.0 dask>=2022.0.0 # for dev -codespell>=2.0.0 -black>=21.9b0 -isort>=5.9.3 -flake8>=4.0.1 -pylint>=3.0.3 pre-commit>=2.15.0