
Commit 59c4947

Merge remote-tracking branch 'origin/main' into feature_dataset_yaml
sadamov committed May 31, 2024
2 parents 6423fdf + 9d558d1 commit 59c4947
Showing 12 changed files with 263 additions and 47 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@ graphs
*.sif
sweeps
test_*.sh
.vscode
cosmo_hilam.html
normalization.zarr

4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Robust restoration of optimizer and scheduler using `ckpt_path`
[\#17](https://github.com/mllam/neural-lam/pull/17)
@sadamov

- Updated scripts and modules to use `data_config.yaml` instead of `constants.py`
[\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
@sadamov
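
For context, the first entry above refers to resuming training via PyTorch Lightning's `ckpt_path`, which restores optimizer and scheduler state along with the model weights. A minimal sketch, assuming a Lightning module and data module have already been constructed (the checkpoint path is hypothetical, not from this commit):

# Sketch only: not part of this commit.
import pytorch_lightning as pl

def resume_training(model: pl.LightningModule,
                    data_module: pl.LightningDataModule,
                    ckpt_path: str = "saved_models/last.ckpt") -> None:
    # Passing ckpt_path makes Lightning restore model weights *and*
    # optimizer/LR-scheduler state before continuing training.
    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)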
67 changes: 67 additions & 0 deletions create_grid_features.py
@@ -0,0 +1,67 @@
# Standard library
import os
from argparse import ArgumentParser

# Third-party
import numpy as np
import torch

# First-party
from neural_lam import config


def main():
"""
Pre-compute all static features related to the grid nodes
"""
parser = ArgumentParser(description="Training arguments")
parser.add_argument(
"--data_config",
type=str,
default="neural_lam/data_config.yaml",
help="Path to data config file (default: neural_lam/data_config.yaml)",
)
args = parser.parse_args()
config_loader = config.Config.from_file(args.data_config)

static_dir_path = os.path.join(
"data", config_loader.dataset.name, "static"
)

# -- Static grid node features --
grid_xy = torch.tensor(
np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
) # (2, N_x, N_y)
grid_xy = grid_xy.flatten(1, 2).T # (N_grid, 2)
pos_max = torch.max(torch.abs(grid_xy))
grid_xy = grid_xy / pos_max # Divide by maximum coordinate

geopotential = torch.tensor(
np.load(os.path.join(static_dir_path, "surface_geopotential.npy"))
) # (N_x, N_y)
geopotential = geopotential.flatten(0, 1).unsqueeze(1) # (N_grid,1)
gp_min = torch.min(geopotential)
gp_max = torch.max(geopotential)
# Rescale geopotential to [0,1]
geopotential = (geopotential - gp_min) / (gp_max - gp_min) # (N_grid, 1)

grid_border_mask = torch.tensor(
np.load(os.path.join(static_dir_path, "border_mask.npy")),
dtype=torch.int64,
) # (N_x, N_y)
grid_border_mask = (
grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1)
) # (N_grid, 1)

# Concatenate grid features
grid_features = torch.cat(
(grid_xy, geopotential, grid_border_mask), dim=1
) # (N_grid, 4)

torch.save(
grid_features, os.path.join(static_dir_path, "grid_features.pt")
)


if __name__ == "__main__":
main()
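
A quick way to sanity-check the artifact this new script writes; a sketch assuming the default directory layout, with `meps_example` as a placeholder dataset name:

# Sketch only: load the tensor saved by create_grid_features.py.
# Columns are normalized x, normalized y, rescaled geopotential and the
# border mask, giving shape (N_grid, 4) as in the comments above.
import os
import torch

static_dir = os.path.join("data", "meps_example", "static")
grid_features = torch.load(os.path.join(static_dir, "grid_features.pt"))
print(grid_features.shape)  # torch.Size([N_grid, 4])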
2 changes: 2 additions & 0 deletions create_mesh.py
@@ -160,6 +160,7 @@ def prepend_node_index(graph, new_index):
def main():
parser = ArgumentParser(description="Graph generation arguments")
parser.add_argument(
"--data_config",
"--data_config",
type=str,
default="neural_lam/data_config.yaml",
@@ -193,6 +194,7 @@ def main():
args = parser.parse_args()

# Load grid positions
config_loader = config.Config.from_file(args.data_config)
graph_dir_path = os.path.join("graphs", args.graph)
os.makedirs(graph_dir_path, exist_ok=True)

48 changes: 47 additions & 1 deletion create_parameter_weights.py
@@ -1,20 +1,29 @@
# Standard library
import os
from argparse import ArgumentParser

# Third-party
import numpy as np
import torch
import xarray as xr
from tqdm import tqdm

# First-party
from neural_lam.weather_dataset import WeatherDataModule
from neural_lam import config
from neural_lam.weather_dataset import WeatherDataModule, WeatherDataset


def main():
"""
Pre-compute parameter weights to be used in loss function
"""
parser = ArgumentParser(description="Training arguments")
parser.add_argument(
"--data_config",
type=str,
default="neural_lam/data_config.yaml",
help="Path to data config file (default: neural_lam/data_config.yaml)",
)
parser.add_argument(
"--batch_size",
type=int,
@@ -36,13 +45,50 @@ def main():

args = parser.parse_args()

config_loader = config.Config.from_file(args.data_config)
static_dir_path = os.path.join(
"data", config_loader.dataset.name, "static"
)

# Create parameter weights based on height
# based on fig A.1 in graph cast paper
w_dict = {
"2": 1.0,
"0": 0.1,
"65": 0.065,
"1000": 0.1,
"850": 0.05,
"500": 0.03,
}
w_list = np.array(
[
w_dict[par.split("_")[-2]]
for par in config_loader.dataset.var_longnames
]
)
print("Saving parameter weights...")
np.save(
os.path.join(static_dir_path, "parameter_weights.npy"),
w_list.astype("float32"),
)
data_module = WeatherDataModule(
batch_size=args.batch_size, num_workers=args.num_workers
)
data_module.setup()
loader = data_module.train_dataloader()

# Load dataset without any subsampling
ds = WeatherDataset(
config_loader.dataset.name,
split="train",
subsample_step=1,
pred_length=63,
standardize=False,
) # Without standardization
loader = torch.utils.data.DataLoader(
ds, args.batch_size, shuffle=False, num_workers=args.n_workers
)
# Compute mean and std.-dev. of each parameter (+ flux forcing)
# Compute mean and std.-dev. of each parameter (+ forcing)
# across full dataset
print("Computing mean and std.-dev. for parameters...")
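
The weight lookup added above keys on the second-to-last underscore-separated token of each variable long name (its vertical level). A small sketch with a hypothetical long name; the real names come from `data_config.yaml`:

# Sketch only: level-based weights following fig. A.1 of the GraphCast paper.
w_dict = {"2": 1.0, "0": 0.1, "65": 0.065, "1000": 0.1, "850": 0.05, "500": 0.03}

longname = "air_temperature_heightAboveGround_2_instant"  # hypothetical example
level = longname.split("_")[-2]  # -> "2"
weight = w_dict[level]           # -> 1.0
print(level, weight)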
39 changes: 25 additions & 14 deletions neural_lam/config.py
@@ -1,5 +1,7 @@
# Standard library
import functools
import os
from pathlib import Path

# Third-party
import cartopy.crs as ccrs
@@ -12,21 +14,21 @@ class Config:
"""
Class for loading configuration files.
This class loads a YAML configuration file and provides a way to access
its values as attributes.
This class loads a configuration file and provides a way to access its
values as attributes.
"""

def __init__(self, config_path, values=None):
self.config_path = config_path
if values is None:
self.values = self.load_config()
else:
self.values = values
def __init__(self, values):
self.values = values

def load_config(self):
"""Load configuration file."""
with open(self.config_path, encoding="utf-8", mode="r") as file:
return yaml.safe_load(file)
@classmethod
def from_file(cls, filepath):
"""Load a configuration file."""
if filepath.endswith(".yaml"):
with open(filepath, encoding="utf-8", mode="r") as file:
return cls(values=yaml.safe_load(file))
else:
raise NotImplementedError(Path(filepath).suffix)

def __getattr__(self, name):
keys = name.split(".")
@@ -37,18 +39,27 @@ def __getattr__(self, name):
else:
return None
if isinstance(value, dict):
return Config(None, values=value)
return Config(values=value)
return value

def __getitem__(self, key):
value = self.values[key]
if isinstance(value, dict):
return Config(None, values=value)
return Config(values=value)
return value

def __contains__(self, key):
return key in self.values

@functools.cached_property
def coords_projection(self):
"""Return the projection."""
proj_config = self.values["projection"]
proj_class_name = proj_config["class"]
proj_class = getattr(ccrs, proj_class_name)
proj_params = proj_config.get("kwargs", {})
return proj_class(**proj_params)

def param_names(self):
"""Return parameter names."""
surface_names = self.values["state"]["surface"]
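
To illustrate the new `Config` API introduced above, a short usage sketch; the path is the repository default and the printed values depend on the YAML contents:

# Sketch only: load a YAML config and use attribute/item access plus the
# cached cartopy projection built from projection.class / projection.kwargs.
from neural_lam import config

cfg = config.Config.from_file("neural_lam/data_config.yaml")
print(cfg.dataset.name)       # nested dicts come back wrapped in Config
print("projection" in cfg)    # __contains__ checks top-level keys
proj = cfg.coords_projection  # a cartopy CRS, e.g. ccrs.LambertConformal
print(type(proj))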
26 changes: 15 additions & 11 deletions neural_lam/models/ar_model.py
@@ -24,7 +24,7 @@ def __init__(self, args):
super().__init__()
self.save_hyperparameters()
self.args = args
self.config_loader = config.Config(args.data_config)
self.config_loader = config.Config.from_file(args.data_config)

# Load static features for grid/data
static = self.config_loader.process_dataset("static")
@@ -78,8 +78,8 @@ def __init__(self, args):
if self.output_std:
self.test_metrics["output_std"] = [] # Treat as metric

# For making restoring of optimizer state optional (slight hack)
self.opt_state = None
# For making restoring of optimizer state optional
self.restore_opt = args.restore_opt

# For example plotting
self.n_example_pred = args.n_example_pred
@@ -107,9 +107,6 @@ def configure_optimizers(self):
opt = torch.optim.AdamW(
self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
)
if self.opt_state:
opt.load_state_dict(self.opt_state)

return opt

@property
@@ -278,7 +275,7 @@ def validation_step(self, batch, batch_idx):
# Log loss per time step forward and mean
val_log_dict = {
f"val_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_log
for step in self.args.val_steps_to_log
}
val_log_dict["val_mean_loss"] = mean_loss
self.log_dict(
@@ -326,7 +323,7 @@ def test_step(self, batch, batch_idx):
# Log loss per time step forward and mean
test_log_dict = {
f"test_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_log
for step in self.args.val_steps_to_log
}
test_log_dict["test_mean_loss"] = mean_loss

@@ -359,7 +356,9 @@ def test_step(self, batch, batch_idx):
spatial_loss = self.loss(
prediction, target, pred_std, average_grid=False
) # (B, pred_steps, num_grid_nodes)
log_spatial_losses = spatial_loss[:, self.args.val_steps_log - 1]
log_spatial_losses = spatial_loss[
:, [step - 1 for step in self.args.val_steps_to_log]
]
self.spatial_loss_maps.append(log_spatial_losses)
# (B, N_log, num_grid_nodes)

@@ -438,6 +437,8 @@ def plot_examples(self, batch, n_examples, prediction=None):
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
self.config_loader.dataset.var_names,
self.config_loader.dataset.var_units,
self.config_loader.param_names(),
self.config_loader.param_units(),
var_vranges,
@@ -578,7 +579,7 @@ def on_test_epoch_end(self):
title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
)
for t_i, loss_map in zip(
self.args.val_steps_log, mean_spatial_loss
self.args.val_steps_to_log, mean_spatial_loss
)
]

@@ -597,7 +598,7 @@ def on_test_epoch_end(self):
wandb.run.dir, "spatial_loss_maps"
)
os.makedirs(pdf_loss_maps_dir, exist_ok=True)
for t_i, fig in zip(self.args.val_steps_log, pdf_loss_map_figs):
for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
fig.savefig(
os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")
)
@@ -630,3 +631,6 @@ def on_load_checkpoint(self, checkpoint):
)
loaded_state_dict[new_key] = loaded_state_dict[old_key]
del loaded_state_dict[old_key]
if not self.restore_opt:
opt = self.configure_optimizers()
checkpoint["optimizer_states"] = [opt.state_dict()]
2 changes: 1 addition & 1 deletion neural_lam/models/base_hi_graph_model.py
@@ -36,7 +36,7 @@ def __init__(self, args):
if level_index < (self.num_levels - 1):
up_edges = self.mesh_up_features[level_index].shape[0]
down_edges = self.mesh_down_features[level_index].shape[0]
print(f" {level_index}<->{level_index+1}")
print(f" {level_index}<->{level_index + 1}")
print(f" - {up_edges} up edges, {down_edges} down edges")
# Embedders
# Assume all levels have same static feature dimensionality
(Diffs for the remaining changed files were not loaded on this page.)
