Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed May 28, 2024
1 parent 5f538f9 commit 6685e94
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 52 deletions.
8 changes: 6 additions & 2 deletions neural_lam/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

# Standard library
import os

# Third-party
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
Expand Down Expand Up @@ -86,7 +87,10 @@ def open_zarr(self, dataset_name):
"""Open a dataset specified by the dataset name."""
dataset_path = self.zarrs[dataset_name].path
if dataset_path is None or not os.path.exists(dataset_path):
print(f"Dataset '{dataset_name}' not found at path: {dataset_path}")
print(
f"Dataset '{dataset_name}' "
f"not found at path: {dataset_path}"
)
return None
dataset = xr.open_zarr(dataset_path, consolidated=True)
return dataset
Expand Down
80 changes: 37 additions & 43 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,19 @@
import numpy as np
import pytorch_lightning as pl
import torch

import wandb

# First-party
from neural_lam import metrics, vis
from neural_lam import config, metrics, vis


class ARModel(pl.LightningModule):
"""
Generic auto-regressive weather model.
Abstract class that can be extended.
Generic auto-regressive weather model. Abstract class that can be extended.
"""

# pylint: disable=arguments-differ
# Disable to override args/kwargs from superclass
# pylint: disable=arguments-differ Disable to override args/kwargs from
# superclass

def __init__(self, args):
super().__init__()
Expand Down Expand Up @@ -127,18 +125,18 @@ def expand_to_batch(x, batch_size):
def predict_step(self, prev_state, prev_prev_state, forcing):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
prev_state: (B, num_grid_nodes, feature_dim), X_t
prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1}
forcing: (B, num_grid_nodes, forcing_dim)
prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B,
num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes,
forcing_dim)
"""
raise NotImplementedError("No prediction step implemented")

def unroll_prediction(self, init_states, forcing_features, true_states):
"""
Roll out prediction taking multiple autoregressive steps with model
init_states: (B, 2, num_grid_nodes, d_f)
forcing_features: (B, pred_steps, num_grid_nodes, d_static_f)
true_states: (B, pred_steps, num_grid_nodes, d_f)
init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B,
pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps,
num_grid_nodes, d_f)
"""
prev_prev_state = init_states[:, 0]
prev_state = init_states[:, 1]
Expand All @@ -153,8 +151,8 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
pred_state, pred_std = self.predict_step(
prev_state, prev_prev_state, forcing
)
# state: (B, num_grid_nodes, d_f)
# pred_std: (B, num_grid_nodes, d_f) or None
# state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes,
# d_f) or None

# Overwrite border with true state
new_state = (
Expand Down Expand Up @@ -184,11 +182,10 @@ def unroll_prediction(self, init_states, forcing_features, true_states):

def common_step(self, batch):
"""
Predict on single batch
batch consists of:
init_states: (B, 2, num_grid_nodes, d_features)
target_states: (B, pred_steps, num_grid_nodes, d_features)
forcing_features: (B, pred_steps, num_grid_nodes, d_forcing),
Predict on single batch batch consists of: init_states: (B, 2,
num_grid_nodes, d_features) target_states: (B, pred_steps,
num_grid_nodes, d_features) forcing_features: (B, pred_steps,
num_grid_nodes, d_forcing),
where index 0 corresponds to index 1 of init_states
"""
(
Expand All @@ -200,8 +197,8 @@ def common_step(self, batch):
prediction, pred_std = self.unroll_prediction(
init_states, forcing_features, target_states
) # (B, pred_steps, num_grid_nodes, d_f)
# prediction: (B, pred_steps, num_grid_nodes, d_f)
# pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)

return prediction, target_states, pred_std

Expand All @@ -214,9 +211,8 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
forcing_features = (
forcing_features - self.forcing_mean
) / self.forcing_std
# boundary_features = (
# boundary_features - self.boundary_mean
# ) / self.boundary_std
# boundary_features = ( boundary_features - self.boundary_mean ) /
# self.boundary_std
batch = (
init_states,
target_states,
Expand Down Expand Up @@ -246,8 +242,8 @@ def training_step(self, batch):

def all_gather_cat(self, tensor_to_gather):
"""
Gather tensors across all ranks, and concatenate across dim. 0
(instead of stacking in new dim. 0)
Gather tensors across all ranks, and concatenate across dim. 0 (instead
of stacking in new dim. 0)
tensor_to_gather: (d1, d2, ...), distributed over K ranks
Expand Down Expand Up @@ -308,8 +304,8 @@ def test_step(self, batch, batch_idx):
Run test on single batch
"""
prediction, target, pred_std = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f)
# pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)

time_step_loss = torch.mean(
self.loss(
Expand All @@ -330,10 +326,9 @@ def test_step(self, batch, batch_idx):
test_log_dict, on_step=False, on_epoch=True, sync_dist=True
)

# Compute all evaluation metrics for error maps
# Note: explicitly list metrics here, as test_metrics can contain
# additional ones, computed differently, but that should be aggregated
# on_test_epoch_end
# Compute all evaluation metrics for error maps Note: explicitly list
# metrics here, as test_metrics can contain additional ones, computed
# differently, but that should be aggregated on_test_epoch_end
for metric_name in ("mse", "mae"):
metric_func = metrics.get_metric(metric_name)
batch_metric_vals = metric_func(
Expand Down Expand Up @@ -378,9 +373,9 @@ def plot_examples(self, batch, n_examples, prediction=None):
"""
Plot the first n_examples forecasts from batch
batch: batch with data to plot corresponding forecasts for
n_examples: number of forecasts to plot
prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction.
batch: batch with data to plot corresponding forecasts for n_examples:
number of forecasts to plot prediction: (B, pred_steps, num_grid_nodes,
d_f), existing prediction.
Generate if None.
"""
if prediction is None:
Expand Down Expand Up @@ -470,15 +465,14 @@ def plot_examples(self, batch, n_examples, prediction=None):

def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
"""
Put together a dict with everything to log for one metric.
Also saves plots as pdf and csv if using test prefix.
Put together a dict with everything to log for one metric. Also saves
plots as pdf and csv if using test prefix.
metric_tensor: (pred_steps, d_f), metric values per time and variable
prefix: string, prefix to use for logging
metric_name: string, name of the metric
prefix: string, prefix to use for logging metric_name: string, name of
the metric
Return:
log_dict: dict with everything to log for given metric
Return: log_dict: dict with everything to log for given metric
"""
log_dict = {}
metric_fig = vis.plot_error_map(
Expand Down Expand Up @@ -552,8 +546,8 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):

def on_test_epoch_end(self):
"""
Compute test metrics and make plots at the end of test epoch.
Will gather stored tensors and perform plotting and logging on rank 0.
Compute test metrics and make plots at the end of test epoch. Will
gather stored tensors and perform plotting and logging on rank 0.
"""
# Create error maps for all test metrics
self.aggregate_and_plot_metrics(self.test_metrics, prefix="test")
Expand Down
12 changes: 6 additions & 6 deletions neural_lam/models/base_hi_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ def process_step(self, mesh_rep):
)

# Update node and edge vectors in lists
mesh_rep_levels[level_l] = (
new_node_rep # (B, num_mesh_nodes[l], d_h)
)
mesh_rep_levels[
level_l
] = new_node_rep # (B, num_mesh_nodes[l], d_h)
mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h)

# - PROCESSOR -
Expand All @@ -207,9 +207,9 @@ def process_step(self, mesh_rep):
new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep)

# Update node and edge vectors in lists
mesh_rep_levels[level_l] = (
new_node_rep # (B, num_mesh_nodes[l], d_h)
)
mesh_rep_levels[
level_l
] = new_node_rep # (B, num_mesh_nodes[l], d_h)

# Return only bottom level representation
return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h)
Expand Down
2 changes: 1 addition & 1 deletion plot_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from trimesh.primitives import Box

# First-party
from neural_lam import utils
from neural_lam import config, utils

MESH_HEIGHT = 0.1
MESH_LEVEL_DIST = 0.05
Expand Down

0 comments on commit 6685e94

Please sign in to comment.