diff --git a/README.md b/README.md index e21b7c24..7a5e5caf 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,9 @@ Once `neural-lam` is installed you will be able to train/evaluate models. For th interface that provides the data in a data-structure that can be used within neural-lam. A datastore is used to create a `pytorch.Dataset`-derived class that samples the data in time to create individual samples for - training, validation and testing. + training, validation and testing. A secondary datastore can be provided + for the boundary data. Currently, boundary datastore must be of type `mdp` + and only contain forcing features. This can easily be expanded in the future. 2. **The graph structure** is used to define message-passing GNN layers, that are trained to emulate fluid flow in the atmosphere over time. The @@ -121,7 +123,7 @@ different aspects about the training and evaluation of the model. The path you provide to the neural-lam config (`config.yaml`) also sets the root directory relative to which all other paths are resolved, as in the parent -directory of the config becomes the root directory. Both the datastore and +directory of the config becomes the root directory. Both the datastores and graphs you generate are then stored in subdirectories of this root directory. Exactly how and where a specific datastore expects its source data to be stored and where it stores its derived data is up to the implementation of the @@ -134,6 +136,7 @@ assume you placed `config.yaml` in a folder called `data`): data/ ├── config.yaml - Configuration file for neural-lam ├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml +├── era5.datastore.zarr/ - Optional configuration file for the boundary datastore, referred to from config.yaml └── graphs/ - Directory containing graphs for training ``` @@ -142,18 +145,20 @@ And the content of `config.yaml` could in this case look like: datastore: kind: mdp config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml training: state_feature_weighting: __config_class__: ManualStateFeatureWeighting - values: + weights: u100m: 1.0 v100m: 1.0 ``` -For now the neural-lam config only defines two things: 1) the kind of data -store and the path to its config, and 2) the weighting of different features in -the loss function. If you don't define the state feature weighting it will default -to weighting all features equally. +For now the neural-lam config only defines two things: +1) the kind of datastores and the path to their config +2) the weighting of different features in the loss function. If you don't define the state feature weighting it will default to weighting all features equally. (This example is taken from the `tests/datastore_examples/mdp` directory.) @@ -525,5 +530,4 @@ Furthermore, all tests in the ```tests``` directory will be run upon pushing cha # Contact If you are interested in machine learning models for LAM, have questions about the implementation or ideas for extending it, feel free to get in touch. -There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join (after following the link you have to request to join, this is to avoid spam bots). -You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). +There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join. You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/neural_lam/build_rectangular_graph.py b/neural_lam/build_rectangular_graph.py new file mode 100644 index 00000000..df7f8ba8 --- /dev/null +++ b/neural_lam/build_rectangular_graph.py @@ -0,0 +1,177 @@ +# Standard library +import argparse +import os + +# Third-party +import numpy as np +import weather_model_graphs as wmg + +# Local +from . import utils +from .config import load_config_and_datastore + +WMG_ARCHETYPES = { + "keisler": wmg.create.archetype.create_keisler_graph, + "graphcast": wmg.create.archetype.create_graphcast_graph, + "hierarchical": wmg.create.archetype.create_oskarsson_hierarchical_graph, +} + + +def main(input_args=None): + parser = argparse.ArgumentParser( + description="Rectangular graph generation using weather-models-graph", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Inputs and outputs + parser.add_argument( + "--config_path", + type=str, + help="Path to the configuration for neural-lam", + ) + parser.add_argument( + "--graph_name", + type=str, + help="Name to save graph as (default: multiscale)", + ) + parser.add_argument( + "--output_dir", + type=str, + default="graphs", + help="Directory to save graph to", + ) + + # Graph structure + parser.add_argument( + "--archetype", + type=str, + default="keisler", + help="Archetype to use to create graph " + "(keisler/graphcast/hierarchical)", + ) + parser.add_argument( + "--mesh_node_distance", + type=float, + default=3.0, + help="Distance between created mesh nodes", + ) + parser.add_argument( + "--level_refinement_factor", + type=float, + default=3, + help="Refinement factor between grid points and bottom level of " + "mesh hierarchy", + ) + parser.add_argument( + "--max_num_levels", + type=int, + help="Limit multi-scale mesh to given number of levels, " + "from bottom up", + ) + args = parser.parse_args(input_args) + + assert ( + args.config_path is not None + ), "Specify your config with --config_path" + assert ( + args.graph_name is not None + ), "Specify the name to save graph as with --graph_name" + + _, datastore = load_config_and_datastore(config_path=args.config_path) + + # Load grid positions + # TODO Do not get normalised positions + coords = utils.get_reordered_grid_pos(datastore).numpy() + # (num_nodes_full, 2) + + # Construct mask + num_full_grid = coords.shape[0] + num_boundary = datastore.boundary_mask.to_numpy().sum() + num_interior = num_full_grid - num_boundary + decode_mask = np.concatenate( + ( + np.ones(num_interior, dtype=bool), + np.zeros(num_boundary, dtype=bool), + ), + axis=0, + ) + + # Build graph + assert ( + args.archetype in WMG_ARCHETYPES + ), f"Unknown archetype: {args.archetype}" + archetype_create_func = WMG_ARCHETYPES[args.archetype] + + create_kwargs = { + "coords": coords, + "mesh_node_distance": args.mesh_node_distance, + "decode_mask": decode_mask, + "return_components": True, + } + if args.archetype != "keisler": + # Add additional multi-level kwargs + create_kwargs.update( + { + "level_refinement_factor": args.level_refinement_factor, + "max_num_levels": args.max_num_levels, + } + ) + + graph_comp = archetype_create_func(**create_kwargs) + + print("Created graph:") + for name, subgraph in graph_comp.items(): + print(f"{name}: {subgraph}") + + # Save graph + graph_dir_path = os.path.join( + datastore.root_path, "graphs", args.graph_name + ) + os.makedirs(graph_dir_path, exist_ok=True) + for component, graph in graph_comp.items(): + # This seems like a bit of a hack, maybe better if saving in wmg + # was made consistent with nl + if component == "m2m": + if args.archetype == "hierarchical": + # Split by direction + m2m_direction_comp = wmg.split_graph_by_edge_attribute( + graph, attr="direction" + ) + for direction, graph in m2m_direction_comp.items(): + if direction == "same": + # Name just m2m to be consistent with non-hierarchical + wmg.save.to_pyg( + graph=graph, + name="m2m", + list_from_attribute="level", + edge_features=["len", "vdiff"], + output_directory=graph_dir_path, + ) + else: + # up and down directions + wmg.save.to_pyg( + graph=graph, + name=f"mesh_{direction}", + list_from_attribute="levels", + edge_features=["len", "vdiff"], + output_directory=graph_dir_path, + ) + else: + wmg.save.to_pyg( + graph=graph, + name=component, + list_from_attribute="dummy", # Note: Needed to output list + edge_features=["len", "vdiff"], + output_directory=graph_dir_path, + ) + else: + wmg.save.to_pyg( + graph=graph, + name=component, + edge_features=["len", "vdiff"], + output_directory=graph_dir_path, + ) + + +if __name__ == "__main__": + main() diff --git a/neural_lam/config.py b/neural_lam/config.py index d3e09697..914ebb38 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -168,4 +168,15 @@ def load_config_and_datastore( datastore_kind=config.datastore.kind, config_path=datastore_config_path ) - return config, datastore + if config.datastore_boundary is not None: + datastore_boundary_config_path = ( + Path(config_path).parent / config.datastore_boundary.config_path + ) + datastore_boundary = init_datastore( + datastore_kind=config.datastore_boundary.kind, + config_path=datastore_boundary_config_path, + ) + else: + datastore_boundary = None + + return config, datastore, datastore_boundary diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py deleted file mode 100644 index ef979be3..00000000 --- a/neural_lam/create_graph.py +++ /dev/null @@ -1,610 +0,0 @@ -# Standard library -import os -from argparse import ArgumentParser - -# Third-party -import matplotlib -import matplotlib.pyplot as plt -import networkx -import numpy as np -import scipy.spatial -import torch -import torch_geometric as pyg -from torch_geometric.utils.convert import from_networkx - -# Local -from .config import load_config_and_datastore -from .datastore.base import BaseRegularGridDatastore - - -def plot_graph(graph, title=None): - fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H - edge_index = graph.edge_index - pos = graph.pos - - # Fix for re-indexed edge indices only containing mesh nodes at - # higher levels in hierarchy - edge_index = edge_index - edge_index.min() - - if pyg.utils.is_undirected(edge_index): - # Keep only 1 direction of edge_index - edge_index = edge_index[:, edge_index[0] < edge_index[1]] # (2, M/2) - # TODO: indicate direction of directed edges - - # Move all to cpu and numpy, compute (in)-degrees - degrees = ( - pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy() - ) - edge_index = edge_index.cpu().numpy() - pos = pos.cpu().numpy() - - # Plot edges - from_pos = pos[edge_index[0]] # (M/2, 2) - to_pos = pos[edge_index[1]] # (M/2, 2) - edge_lines = np.stack((from_pos, to_pos), axis=1) - axis.add_collection( - matplotlib.collections.LineCollection( - edge_lines, lw=0.4, colors="black", zorder=1 - ) - ) - - # Plot nodes - node_scatter = axis.scatter( - pos[:, 0], - pos[:, 1], - c=degrees, - s=3, - marker="o", - zorder=2, - cmap="viridis", - clim=None, - ) - - plt.colorbar(node_scatter, aspect=50) - - if title is not None: - axis.set_title(title) - - return fig, axis - - -def sort_nodes_internally(nx_graph): - # For some reason the networkx .nodes() return list can not be sorted, - # but this is the ordering used by pyg when converting. - # This function fixes this. - H = networkx.DiGraph() - H.add_nodes_from(sorted(nx_graph.nodes(data=True))) - H.add_edges_from(nx_graph.edges(data=True)) - return H - - -def save_edges(graph, name, base_path): - torch.save( - graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt") - ) - edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( - torch.float32 - ) # Save as float32 - torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) - - -def save_edges_list(graphs, name, base_path): - torch.save( - [graph.edge_index for graph in graphs], - os.path.join(base_path, f"{name}_edge_index.pt"), - ) - edge_features = [ - torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( - torch.float32 - ) - for graph in graphs - ] # Save as float32 - torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) - - -def from_networkx_with_start_index(nx_graph, start_index): - pyg_graph = from_networkx(nx_graph) - pyg_graph.edge_index += start_index - return pyg_graph - - -def mk_2d_graph(xy, nx, ny): - xm, xM = np.amin(xy[:, :, 0][:, 0]), np.amax(xy[:, :, 0][:, 0]) - ym, yM = np.amin(xy[:, :, 1][0, :]), np.amax(xy[:, :, 1][0, :]) - - # avoid nodes on border - dx = (xM - xm) / nx - dy = (yM - ym) / ny - lx = np.linspace(xm + dx / 2, xM - dx / 2, nx) - ly = np.linspace(ym + dy / 2, yM - dy / 2, ny) - - mg = np.meshgrid(lx, ly, indexing="ij") # Use 'ij' indexing for (Nx,Ny) - g = networkx.grid_2d_graph(len(lx), len(ly)) - - for node in g.nodes: - g.nodes[node]["pos"] = np.array([mg[0][node], mg[1][node]]) - - # add diagonal edges - g.add_edges_from( - [((x, y), (x + 1, y + 1)) for y in range(ny - 1) for x in range(nx - 1)] - + [ - ((x + 1, y), (x, y + 1)) - for y in range(ny - 1) - for x in range(nx - 1) - ] - ) - - # turn into directed graph - dg = networkx.DiGraph(g) - for u, v in g.edges(): - d = np.sqrt(np.sum((g.nodes[u]["pos"] - g.nodes[v]["pos"]) ** 2)) - dg.edges[u, v]["len"] = d - dg.edges[u, v]["vdiff"] = g.nodes[u]["pos"] - g.nodes[v]["pos"] - dg.add_edge(v, u) - dg.edges[v, u]["len"] = d - dg.edges[v, u]["vdiff"] = g.nodes[v]["pos"] - g.nodes[u]["pos"] - - return dg - - -def prepend_node_index(graph, new_index): - # Relabel node indices in graph, insert (graph_level, i, j) - ijk = [tuple((new_index,) + x) for x in graph.nodes] - to_mapping = dict(zip(graph.nodes, ijk)) - return networkx.relabel_nodes(graph, to_mapping, copy=True) - - -def create_graph( - graph_dir_path: str, - xy: np.ndarray, - n_max_levels: int, - hierarchical: bool, - create_plot: bool, -): - """ - Create graph components from `xy` grid coordinates and store in - `graph_dir_path`. - - Creates the following files for all graphs: - - g2m_edge_index.pt [2, N_g2m_edges] - - g2m_features.pt [N_g2m_edges, d_features] - - m2g_edge_index.pt [2, N_m2m_edges] - - m2g_features.pt [N_m2m_edges, d_features] - - m2m_edge_index.pt list of [2, N_m2m_edges_level], length==n_levels - - m2m_features.pt list of [N_m2m_edges_level, d_features], - length==n_levels - - mesh_features.pt list of [N_mesh_nodes_level, d_mesh_static], - length==n_levels - - where - d_features: - number of features per edge (currently d_features==3, for - edge-length, x and y) - N_g2m_edges: - number of edges in the graph from grid-to-mesh - N_m2g_edges: - number of edges in the graph from mesh-to-grid - N_m2m_edges_level: - number of edges in the graph from mesh-to-mesh at a given level - (list index corresponds to the level) - d_mesh_static: - number of static features per mesh node (currently - d_mesh_static==2, for x and y) - N_mesh_nodes_level: - number of nodes in the mesh at a given level - - And in addition for hierarchical graphs: - - mesh_up_edge_index.pt - list of [2, N_mesh_updown_edges_level], length==n_levels-1 - - mesh_up_features.pt - list of [N_mesh_updown_edges_level, d_features], length==n_levels-1 - - mesh_down_edge_index.pt - list of [2, N_mesh_updown_edges_level], length==n_levels-1 - - mesh_down_features.pt - list of [N_mesh_updown_edges_level, d_features], length==n_levels-1 - - where N_mesh_updown_edges_level is the number of edges in the graph from - mesh-to-mesh between two consecutive levels (list index corresponds index - of lower level) - - - Parameters - ---------- - graph_dir_path : str - Path to store the graph components. - xy : np.ndarray - Grid coordinates, expected to be of shape (Nx, Ny, 2). - n_max_levels : int - Limit multi-scale mesh to given number of levels, from bottom up - (default: None (no limit)). - hierarchical : bool - Generate hierarchical mesh graph (default: False). - create_plot : bool - If graphs should be plotted during generation (default: False). - - Returns - ------- - None - - """ - os.makedirs(graph_dir_path, exist_ok=True) - - print(f"Writing graph components to {graph_dir_path}") - - grid_xy = torch.tensor(xy) - pos_max = torch.max(torch.abs(grid_xy)) - - # - # Mesh - # - - # graph geometry - nx = 3 # number of children =nx**2 - nlev = int(np.log(max(xy.shape[:2])) / np.log(nx)) - nleaf = nx**nlev # leaves at the bottom = nleaf**2 - - mesh_levels = nlev - 1 - if n_max_levels: - # Limit the levels in mesh graph - mesh_levels = min(mesh_levels, n_max_levels) - - # print(f"nlev: {nlev}, nleaf: {nleaf}, mesh_levels: {mesh_levels}") - - # multi resolution tree levels - G = [] - for lev in range(1, mesh_levels + 1): - n = int(nleaf / (nx**lev)) - g = mk_2d_graph(xy, n, n) - if create_plot: - plot_graph(from_networkx(g), title=f"Mesh graph, level {lev}") - plt.show() - - G.append(g) - - if hierarchical: - # Relabel nodes of each level with level index first - G = [ - prepend_node_index(graph, level_i) - for level_i, graph in enumerate(G) - ] - - num_nodes_level = np.array([len(g_level.nodes) for g_level in G]) - # First node index in each level in the hierarchical graph - first_index_level = np.concatenate( - (np.zeros(1, dtype=int), np.cumsum(num_nodes_level[:-1])) - ) - - # Create inter-level mesh edges - up_graphs = [] - down_graphs = [] - for from_level, to_level, G_from, G_to, start_index in zip( - range(1, mesh_levels), - range(0, mesh_levels - 1), - G[1:], - G[:-1], - first_index_level[: mesh_levels - 1], - ): - # start out from graph at from level - G_down = G_from.copy() - G_down.clear_edges() - G_down = networkx.DiGraph(G_down) - - # Add nodes of to level - G_down.add_nodes_from(G_to.nodes(data=True)) - - # build kd tree for mesh point pos - # order in vm should be same as in vm_xy - v_to_list = list(G_to.nodes) - v_from_list = list(G_from.nodes) - v_from_xy = np.array([xy for _, xy in G_from.nodes.data("pos")]) - kdt_m = scipy.spatial.KDTree(v_from_xy) - - # add edges from mesh to grid - for v in v_to_list: - # find 1(?) nearest neighbours (index to vm_xy) - neigh_idx = kdt_m.query(G_down.nodes[v]["pos"], 1)[1] - u = v_from_list[neigh_idx] - - # add edge from mesh to grid - G_down.add_edge(u, v) - d = np.sqrt( - np.sum( - (G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2 - ) - ) - G_down.edges[u, v]["len"] = d - G_down.edges[u, v]["vdiff"] = ( - G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"] - ) - - # relabel nodes to integers (sorted) - G_down_int = networkx.convert_node_labels_to_integers( - G_down, first_label=start_index, ordering="sorted" - ) # Issue with sorting here - G_down_int = sort_nodes_internally(G_down_int) - pyg_down = from_networkx_with_start_index(G_down_int, start_index) - - # Create up graph, invert downwards edges - up_edges = torch.stack( - (pyg_down.edge_index[1], pyg_down.edge_index[0]), dim=0 - ) - pyg_up = pyg_down.clone() - pyg_up.edge_index = up_edges - - up_graphs.append(pyg_up) - down_graphs.append(pyg_down) - - if create_plot: - plot_graph( - pyg_down, title=f"Down graph, {from_level} -> {to_level}" - ) - plt.show() - - plot_graph( - pyg_down, title=f"Up graph, {to_level} -> {from_level}" - ) - plt.show() - - # Save up and down edges - save_edges_list(up_graphs, "mesh_up", graph_dir_path) - save_edges_list(down_graphs, "mesh_down", graph_dir_path) - - # Extract intra-level edges for m2m - m2m_graphs = [ - from_networkx_with_start_index( - networkx.convert_node_labels_to_integers( - level_graph, first_label=start_index, ordering="sorted" - ), - start_index, - ) - for level_graph, start_index in zip(G, first_index_level) - ] - - mesh_pos = [graph.pos.to(torch.float32) for graph in m2m_graphs] - - # For use in g2m and m2g - G_bottom_mesh = G[0] - - joint_mesh_graph = networkx.union_all([graph for graph in G]) - all_mesh_nodes = joint_mesh_graph.nodes(data=True) - - else: - # combine all levels to one graph - G_tot = G[0] - for lev in range(1, len(G)): - nodes = list(G[lev - 1].nodes) - n = int(np.sqrt(len(nodes))) - ij = ( - np.array(nodes) - .reshape((n, n, 2))[1::nx, 1::nx, :] - .reshape(int(n / nx) ** 2, 2) - ) - ij = [tuple(x) for x in ij] - G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij))) - G_tot = networkx.compose(G_tot, G[lev]) - - # Relabel mesh nodes to start with 0 - G_tot = prepend_node_index(G_tot, 0) - - # relabel nodes to integers (sorted) - G_int = networkx.convert_node_labels_to_integers( - G_tot, first_label=0, ordering="sorted" - ) - - # Graph to use in g2m and m2g - G_bottom_mesh = G_tot - all_mesh_nodes = G_tot.nodes(data=True) - - # export the nx graph to PyTorch geometric - pyg_m2m = from_networkx(G_int) - m2m_graphs = [pyg_m2m] - mesh_pos = [pyg_m2m.pos.to(torch.float32)] - - if create_plot: - plot_graph(pyg_m2m, title="Mesh-to-mesh") - plt.show() - - # Save m2m edges - save_edges_list(m2m_graphs, "m2m", graph_dir_path) - - # Divide mesh node pos by max coordinate of grid cell - mesh_pos = [pos / pos_max for pos in mesh_pos] - - # Save mesh positions - torch.save( - mesh_pos, os.path.join(graph_dir_path, "mesh_features.pt") - ) # mesh pos, in float32 - - # - # Grid2Mesh - # - - # radius within which grid nodes are associated with a mesh node - # (in terms of mesh distance) - DM_SCALE = 0.67 - - # mesh nodes on lowest level - vm = G_bottom_mesh.nodes - vm_xy = np.array([xy for _, xy in vm.data("pos")]) - # distance between mesh nodes - dm = np.sqrt( - np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2) - ) - - # grid nodes - Nx, Ny = xy.shape[:2] - - G_grid = networkx.grid_2d_graph(Ny, Nx) - G_grid.clear_edges() - - # vg features (only pos introduced here) - for node in G_grid.nodes: - # pos is in feature but here explicit for convenience - G_grid.nodes[node]["pos"] = xy[ - node[1], node[0] - ] # xy is already (Nx,Ny,2) - - # add 1000 to node key to separate grid nodes (1000,i,j) from mesh nodes - # (i,j) and impose sorting order such that vm are the first nodes - G_grid = prepend_node_index(G_grid, 1000) - - # build kd tree for grid point pos - # order in vg_list should be same as in vg_xy - vg_list = list(G_grid.nodes) - vg_xy = np.array( - [xy[node[2], node[1]] for node in vg_list] - ) # xy is already (Nx,Ny,2) - kdt_g = scipy.spatial.KDTree(vg_xy) - - # now add (all) mesh nodes, include features (pos) - G_grid.add_nodes_from(all_mesh_nodes) - - # Re-create graph with sorted node indices - # Need to do sorting of nodes this way for indices to map correctly to pyg - G_g2m = networkx.Graph() - G_g2m.add_nodes_from(sorted(G_grid.nodes(data=True))) - - # turn into directed graph - G_g2m = networkx.DiGraph(G_g2m) - - # add edges - for v in vm: - # find neighbours (index to vg_xy) - neigh_idxs = kdt_g.query_ball_point(vm[v]["pos"], dm * DM_SCALE) - for i in neigh_idxs: - u = vg_list[i] - # add edge from grid to mesh - G_g2m.add_edge(u, v) - d = np.sqrt( - np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2) - ) - G_g2m.edges[u, v]["len"] = d - G_g2m.edges[u, v]["vdiff"] = ( - G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"] - ) - - pyg_g2m = from_networkx(G_g2m) - - if create_plot: - plot_graph(pyg_g2m, title="Grid-to-mesh") - plt.show() - - # - # Mesh2Grid - # - - # start out from Grid2Mesh and then replace edges - G_m2g = G_g2m.copy() - G_m2g.clear_edges() - - # build kd tree for mesh point pos - # order in vm should be same as in vm_xy - vm_list = list(vm) - kdt_m = scipy.spatial.KDTree(vm_xy) - - # add edges from mesh to grid - for v in vg_list: - # find 4 nearest neighbours (index to vm_xy) - neigh_idxs = kdt_m.query(G_m2g.nodes[v]["pos"], 4)[1] - for i in neigh_idxs: - u = vm_list[i] - # add edge from mesh to grid - G_m2g.add_edge(u, v) - d = np.sqrt( - np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2) - ) - G_m2g.edges[u, v]["len"] = d - G_m2g.edges[u, v]["vdiff"] = ( - G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"] - ) - - # relabel nodes to integers (sorted) - G_m2g_int = networkx.convert_node_labels_to_integers( - G_m2g, first_label=0, ordering="sorted" - ) - pyg_m2g = from_networkx(G_m2g_int) - - if create_plot: - plot_graph(pyg_m2g, title="Mesh-to-grid") - plt.show() - - # Save g2m and m2g everything - # g2m - save_edges(pyg_g2m, "g2m", graph_dir_path) - # m2g - save_edges(pyg_m2g, "m2g", graph_dir_path) - - -def create_graph_from_datastore( - datastore: BaseRegularGridDatastore, - output_root_path: str, - n_max_levels: int = None, - hierarchical: bool = False, - create_plot: bool = False, -): - if isinstance(datastore, BaseRegularGridDatastore): - xy = datastore.get_xy(category="state", stacked=False) - else: - raise NotImplementedError( - "Only graph creation for BaseRegularGridDatastore is supported" - ) - - create_graph( - graph_dir_path=output_root_path, - xy=xy, - n_max_levels=n_max_levels, - hierarchical=hierarchical, - create_plot=create_plot, - ) - - -def cli(input_args=None): - parser = ArgumentParser(description="Graph generation arguments") - parser.add_argument( - "--config_path", - type=str, - help="Path to neural-lam configuration file", - ) - parser.add_argument( - "--name", - type=str, - default="multiscale", - help="Name to save graph as (default: multiscale)", - ) - parser.add_argument( - "--plot", - action="store_true", - help="If graphs should be plotted during generation " - "(default: False)", - ) - parser.add_argument( - "--levels", - type=int, - help="Limit multi-scale mesh to given number of levels, " - "from bottom up (default: None (no limit))", - ) - parser.add_argument( - "--hierarchical", - action="store_true", - help="Generate hierarchical mesh graph (default: False)", - ) - args = parser.parse_args(input_args) - - assert ( - args.config_path is not None - ), "Specify your config with --config_path" - - # Load neural-lam configuration and datastore to use - _, datastore = load_config_and_datastore(config_path=args.config_path) - - create_graph_from_datastore( - datastore=datastore, - output_root_path=os.path.join(datastore.root_path, "graph", args.name), - n_max_levels=args.levels, - hierarchical=args.hierarchical, - create_plot=args.plot, - ) - - -if __name__ == "__main__": - cli() diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 0317c2e5..e2d21404 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -228,23 +228,6 @@ def get_dataarray( """ pass - @cached_property - @abc.abstractmethod - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - pass - @abc.abstractmethod def get_xy(self, category: str) -> np.ndarray: """ @@ -295,8 +278,13 @@ def get_xy_extent(self, category: str) -> List[float]: The extent of the x, y coordinates. """ - xy = self.get_xy(category, stacked=False) - extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()] + xy = self.get_xy(category, stacked=True) + extent = [ + xy[:, 0].min(), + xy[:, 0].max(), + xy[:, 1].min(), + xy[:, 1].max(), + ] return [float(v) for v in extent] @property diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 10593a82..0ed92129 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -1,4 +1,5 @@ # Standard library +import copy import warnings from functools import cached_property from pathlib import Path @@ -26,11 +27,10 @@ class MDPDatastore(BaseRegularGridDatastore): SHORT_NAME = "mdp" - def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): + def __init__(self, config_path, reuse_existing=True): """ Construct a new MDPDatastore from the configuration file at - `config_path`. A boundary mask is created with `n_boundary_points` - boundary points. If `reuse_existing` is True, the dataset is loaded + `config_path`. If `reuse_existing` is True, the dataset is loaded from a zarr file if it exists (unless the config has been modified since the zarr was created), otherwise it is created from the configuration file. @@ -41,8 +41,6 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): The path to the configuration file, this will be fed to the `mllam_data_prep.Config.from_yaml_file` method to then call `mllam_data_prep.create_dataset` to create the dataset. - n_boundary_points : int - The number of boundary points to use in the boundary mask. reuse_existing : bool Whether to reuse an existing dataset zarr file if it exists and its creation date is newer than the configuration file. @@ -69,7 +67,6 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): if self._ds is None: self._ds = mdp.create_dataset(config=self._config) self._ds.to_zarr(fp_ds) - self._n_boundary_points = n_boundary_points print("The loaded datastore contains the following features:") for category in ["state", "forcing", "static"]: @@ -157,8 +154,8 @@ def get_vars_units(self, category: str) -> List[str]: The units of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_units"].values.tolist() @@ -176,8 +173,8 @@ def get_vars_names(self, category: str) -> List[str]: The names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature"].values.tolist() @@ -196,8 +193,8 @@ def get_vars_long_names(self, category: str) -> List[str]: The long names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_long_name"].values.tolist() @@ -252,9 +249,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: The xarray DataArray object with processed dataset. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") - return None + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") + return [] da_category = self._ds[category] @@ -318,37 +315,6 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: ds_stats = self._ds[stats_variables.keys()].rename(stats_variables) return ds_stats - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Produce a 0/1 mask for the boundary points of the dataset, these will - sit at the edges of the domain (in x/y extent) and will be used to mask - out the boundary points from the loss function and to overwrite the - boundary points from the prediction. For now this is created when the - mask is requested, but in the future this could be saved to the zarr - file. - - Returns - ------- - xr.DataArray - A 0/1 mask for the boundary points of the dataset, where 1 is a - boundary point and 0 is not. - - """ - ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) - da_domain_allzero = xr.zeros_like(da_state_variable) - ds_unstacked["boundary_mask"] = da_domain_allzero.isel( - x=slice(self._n_boundary_points, -self._n_boundary_points), - y=slice(self._n_boundary_points, -self._n_boundary_points), - ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) - return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) - @property def coords_projection(self) -> ccrs.Projection: """ @@ -394,7 +360,9 @@ def coords_projection(self) -> ccrs.Projection: class_name = projection_info["class_name"] ProjectionClass = getattr(ccrs, class_name) - kwargs = projection_info["kwargs"] + # need to copy otherwise we modify the dict stored in the dataclass + # in-place + kwargs = copy.deepcopy(projection_info["kwargs"]) globe_kwargs = kwargs.pop("globe", {}) if len(globe_kwargs) > 0: @@ -412,12 +380,21 @@ def grid_shape_state(self): The shape of the cartesian grid for the state variables. """ - ds_state = self.unstack_grid_coords(self._ds["state"]) - da_x, da_y = ds_state.x, ds_state.y + # Boundary data often has no state features + if "state" not in self._ds: + warnings.warn( + "no state data found in datastore" + "returning grid shape from forcing data" + ) + ds_forcing = self.unstack_grid_coords(self._ds["forcing"]) + da_x, da_y = ds_forcing.x, ds_forcing.y + else: + ds_state = self.unstack_grid_coords(self._ds["state"]) + da_x, da_y = ds_state.x, ds_state.y assert da_x.ndim == da_y.ndim == 1 return CartesianGridShape(x=da_x.size, y=da_y.size) - def get_xy(self, category: str, stacked: bool) -> ndarray: + def get_xy(self, category: str, stacked: bool = True) -> ndarray: """Return the x, y coordinates of the dataset. Parameters diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py index f2c80e8a..4207812f 100644 --- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -172,6 +172,7 @@ def main( ar_steps = 63 ds = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=ar_steps, standardize=False, diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 42e80706..dfb0b9c9 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -445,7 +445,9 @@ def _get_single_timeseries_dataarray( * np.timedelta64(1, "h") ) elif d == "analysis_time": - coord_values = self._get_analysis_times(split=split) + coord_values = self._get_analysis_times( + split=split, member_id=member + ) elif d == "y": coord_values = y elif d == "x": @@ -505,7 +507,7 @@ def _get_single_timeseries_dataarray( return da - def _get_analysis_times(self, split) -> List[np.datetime64]: + def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: """Get the analysis times for the given split by parsing the filenames of all the files found for the given split. @@ -513,6 +515,8 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: ---------- split : str The dataset split to get the analysis times for. + member_id : int + The ensemble member to get the analysis times for. Returns ------- @@ -520,8 +524,12 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: The analysis times for the given split. """ + if member_id is None: + # Only interior state data files have member_id, to avoid duplicates + # we only look at the first member for all other categories + member_id = 0 pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT) - pattern = re.sub(r"{member_id:[^}]*}", "*", pattern) + pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern) sample_dir = self.root_path / "samples" / split sample_files = sample_dir.glob(pattern) @@ -606,7 +614,7 @@ def get_vars_long_names(self, category: str) -> List[str]: def get_num_data_vars(self, category: str) -> int: return len(self.get_vars_names(category=category)) - def get_xy(self, category: str, stacked: bool) -> np.ndarray: + def get_xy(self, category: str, stacked: bool = True) -> np.ndarray: """Return the x, y coordinates of the dataset. Parameters @@ -668,34 +676,6 @@ def grid_shape_state(self) -> CartesianGridShape: ny, nx = self.config.grid_shape_state return CartesianGridShape(x=nx, y=ny) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """The boundary mask for the dataset. This is a binary mask that is 1 - where the grid cell is on the boundary of the domain, and 0 otherwise. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions `[grid_index]`. - - """ - xy = self.get_xy(category="state", stacked=False) - xs = xy[:, :, 0] - ys = xy[:, :, 1] - # Check if x-coordinates are constant along columns - assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant" - # Check if y-coordinates are constant along rows - assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant" - # Extract unique x and y coordinates - x = xs[:, 0] # Unique x-coordinates (changes along the first axis) - y = ys[0, :] # Unique y-coordinates (changes along the second axis) - values = np.load(self.root_path / "static" / "border_mask.npy") - da_mask = xr.DataArray( - values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask" - ) - da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int) - return da_mask_stacked_xy - def get_standardization_dataarray(self, category: str) -> xr.Dataset: """Return the standardization dataarray for the given category. This should contain a `{category}_mean` and `{category}_std` variable for diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index 2f45b03f..46223b88 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -30,7 +30,8 @@ def __init__( """ Create a new InteractionNet - edge_index: (2,M), Edges in pyg format + edge_index: (2,M), Edges in pyg format, with both sender and receiver + node indices starting at 0 input_dim: Dimensionality of input representations, for both nodes and edges update_edges: If new edge representations should be computed @@ -52,12 +53,11 @@ def __init__( # Default to input dim if not explicitly given hidden_dim = input_dim - # Make both sender and receiver indices of edge_index start at 0 - edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0] + # any edge_index used here must start sender and rec. nodes at index 0 # Store number of receiver nodes according to edge_index self.num_rec = edge_index[1].max() + 1 - edge_index[0] = ( - edge_index[0] + self.num_rec + edge_index = torch.stack( + (edge_index[0] + self.num_rec, edge_index[1]), dim=0 ) # Make sender indices after rec self.register_buffer("edge_index", edge_index, persistent=False) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index bc4c6719..ef766113 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,5 +1,6 @@ # Standard library import os +from typing import List, Union # Third-party import matplotlib.pyplot as plt @@ -7,12 +8,14 @@ import pytorch_lightning as pl import torch import wandb +import xarray as xr # Local from .. import metrics, vis from ..config import NeuralLAMConfig from ..datastore import BaseDatastore from ..loss_weighting import get_state_feature_weighting +from ..weather_dataset import WeatherDataset class ARModel(pl.LightningModule): @@ -42,18 +45,40 @@ def __init__( da_state_stats = datastore.get_standardization_dataarray( category="state" ) - da_boundary_mask = datastore.boundary_mask num_past_forcing_steps = args.num_past_forcing_steps num_future_forcing_steps = args.num_future_forcing_steps + # TODO: Set based on existing of boundary forcing datastore + # TODO: Adjust what is stored here based on self.boundary_forced + self.boundary_forced = False + + # Set up boundary mask + boundary_mask = torch.tensor( + da_boundary_mask.values, dtype=torch.float32 + ).unsqueeze( + 1 + ) # add feature dim + + self.register_buffer("boundary_mask", boundary_mask, persistent=False) + # Pre-compute interior mask for use in loss function + self.register_buffer( + "interior_mask", 1.0 - self.boundary_mask, persistent=False + ) # (num_grid_nodes, 1), 1 for non-border + # Load static features for grid/data, NB: self.predict_step assumes # dimension order to be (grid_index, static_feature) arr_static = da_static_features.transpose( "grid_index", "static_feature" ).values + static_features_torch = torch.tensor(arr_static, dtype=torch.float32) self.register_buffer( "grid_static_features", - torch.tensor(arr_static, dtype=torch.float32), + static_features_torch[self.interior_mask[:, 0].to(torch.bool)], + persistent=False, + ) + self.register_buffer( + "boundary_static_features", + static_features_torch[self.boundary_mask[:, 0].to(torch.bool)], persistent=False, ) @@ -104,29 +129,28 @@ def __init__( self.num_grid_nodes, grid_static_dim, ) = self.grid_static_features.shape - self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim - + num_forcing_vars + # Factor 2 because of temporal embedding or windowed features + + 2 + * num_forcing_vars * (num_past_forcing_steps + num_future_forcing_steps + 1) ) + if self.boundary_forced: + self.boundary_dim = self.grid_dim # TODO Compute separately + ( + self.num_boundary_nodes, + boundary_static_dim, # TODO Will need for computation below + ) = self.boundary_static_features.shape + self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes + else: + # Only interior grid nodes + self.num_input_nodes = self.num_grid_nodes # Instantiate loss function self.loss = metrics.get_metric(args.loss) - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.float32 - ).unsqueeze( - 1 - ) # add feature dim - - self.register_buffer("boundary_mask", boundary_mask, persistent=False) - # Pre-compute interior mask for use in loss function - self.register_buffer( - "interior_mask", 1.0 - self.boundary_mask, persistent=False - ) # (num_grid_nodes, 1), 1 for non-border - self.val_metrics = { "mse": [], } @@ -147,19 +171,50 @@ def __init__( # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] + def _create_dataarray_from_tensor( + self, + tensor: torch.Tensor, + time: Union[int, List[int]], + split: str, + category: str, + ) -> xr.DataArray: + """ + Create an `xr.DataArray` from a tensor, with the correct dimensions and + coordinates to match the datastore used by the model. This function in + in effect is the inverse of what is returned by + `WeatherDataset.__getitem__`. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to convert to a `xr.DataArray` with dimensions [time, + grid_index, feature]. The tensor will be copied to the CPU if it is + not already there. + time : Union[int,List[int]] + The time index or indices for the data, given as integers or a list + of integers representing epoch time in nanoseconds. The ints will be + copied to the CPU memory if they are not already there. + split : str + The split of the data, either 'train', 'val', or 'test' + category : str + The category of the data, either 'state' or 'forcing' + """ + # TODO: creating an instance of WeatherDataset here on every call is + # not how this should be done but whether WeatherDataset should be + # provided to ARModel or where to put plotting still needs discussion + weather_dataset = WeatherDataset(datastore=self._datastore, split=split) + time = np.array(time.cpu(), dtype="datetime64[ns]") + da = weather_dataset.create_dataarray_from_tensor( + tensor=tensor.cpu().numpy(), time=time, category=category + ) + return da + def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) ) return opt - @property - def interior_mask_bool(self): - """ - Get the interior mask as a boolean (N,) mask. - """ - return self.interior_mask[:, 0].to(torch.bool) - @staticmethod def expand_to_batch(x, batch_size): """ @@ -167,7 +222,9 @@ def expand_to_batch(x, batch_size): """ return x.unsqueeze(0).expand(batch_size, -1, -1) - def predict_step(self, prev_state, prev_prev_state, forcing): + def predict_step( + self, prev_state, prev_prev_state, forcing, boundary_forcing + ): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B, @@ -176,42 +233,41 @@ def predict_step(self, prev_state, prev_prev_state, forcing): """ raise NotImplementedError("No prediction step implemented") - def unroll_prediction(self, init_states, forcing_features, true_states): + def unroll_prediction(self, init_states, forcing, boundary_forcing): """ Roll out prediction taking multiple autoregressive steps with model - init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B, - pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps, - num_grid_nodes, d_f) + init_states: (B, 2, num_grid_nodes, d_f) + forcing: (B, pred_steps, num_grid_nodes, d_static_f) + boundary_forcing: (B, pred_steps, num_boundary_nodes, d_boundary_f) """ prev_prev_state = init_states[:, 0] prev_state = init_states[:, 1] prediction_list = [] pred_std_list = [] - pred_steps = forcing_features.shape[1] + pred_steps = forcing.shape[1] for i in range(pred_steps): - forcing = forcing_features[:, i] - border_state = true_states[:, i] + forcing_step = forcing[:, i] + + if self.boundary_forced: + boundary_forcing_step = boundary_forcing[:, i] + else: + boundary_forcing_step = None pred_state, pred_std = self.predict_step( - prev_state, prev_prev_state, forcing + prev_state, prev_prev_state, forcing_step, boundary_forcing_step ) # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, # d_f) or None - # Overwrite border with true state - new_state = ( - self.boundary_mask * border_state - + self.interior_mask * pred_state - ) + prediction_list.append(pred_state) - prediction_list.append(new_state) if self.output_std: pred_std_list.append(pred_std) # Update conditioning states prev_prev_state = prev_state - prev_state = new_state + prev_state = pred_state prediction = torch.stack( prediction_list, dim=1 @@ -227,16 +283,25 @@ def unroll_prediction(self, init_states, forcing_features, true_states): def common_step(self, batch): """ - Predict on single batch batch consists of: init_states: (B, 2, - num_grid_nodes, d_features) target_states: (B, pred_steps, - num_grid_nodes, d_features) forcing_features: (B, pred_steps, - num_grid_nodes, d_forcing), + Predict on single batch + batch consists of: + init_states: (B, 2, num_grid_nodes, d_features) + target_states: (B, pred_steps, num_grid_nodes, d_features) + forcing: (B, pred_steps, num_grid_nodes, d_forcing), + boundary_forcing: + (B, pred_steps, num_boundary_nodes, d_boundary_forcing), where index 0 corresponds to index 1 of init_states """ - (init_states, target_states, forcing_features, batch_times) = batch + ( + init_states, + target_states, + forcing, + boundary_forcing, + batch_times, + ) = batch prediction, pred_std = self.unroll_prediction( - init_states, forcing_features, target_states + init_states, forcing, boundary_forcing ) # (B, pred_steps, num_grid_nodes, d_f) # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, # pred_steps, num_grid_nodes, d_f) or (d_f,) @@ -249,12 +314,14 @@ def training_step(self, batch): """ prediction, target, pred_std, _ = self.common_step(batch) - # Compute loss + # Compute loss - mean over unrolled times and batch batch_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ) - ) # mean over unrolled times and batch + ) log_dict = {"train_loss": batch_loss} self.log_dict( @@ -288,7 +355,9 @@ def validation_step(self, batch, batch_idx): time_step_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ), dim=0, ) # (time_steps-1) @@ -314,7 +383,6 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.val_metrics["mse"].append(entry_mses) @@ -342,7 +410,9 @@ def test_step(self, batch, batch_idx): time_step_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ), dim=0, ) # (time_steps-1,) @@ -372,16 +442,13 @@ def test_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.test_metrics[metric_name].append(batch_metric_vals) if self.output_std: # Store output std. per variable, spatially averaged - mean_pred_std = torch.mean( - pred_std[..., self.interior_mask_bool, :], dim=-2 - ) # (B, pred_steps, d_f) + mean_pred_std = torch.mean(pred_std, dim=-2) # (B, pred_steps, d_f) self.test_metrics["output_std"].append(mean_pred_std) # Save per-sample spatial loss for specific times @@ -406,10 +473,13 @@ def test_step(self, batch, batch_idx): ) self.plot_examples( - batch, n_additional_examples, prediction=prediction + batch, + n_additional_examples, + prediction=prediction, + split="test", ) - def plot_examples(self, batch, n_examples, prediction=None): + def plot_examples(self, batch, n_examples, split, prediction=None): """ Plot the first n_examples forecasts from batch @@ -422,18 +492,34 @@ def plot_examples(self, batch, n_examples, prediction=None): prediction, target, _, _ = self.common_step(batch) target = batch[1] + time = batch[3] # Rescale to original data scale prediction_rescaled = prediction * self.state_std + self.state_mean target_rescaled = target * self.state_std + self.state_mean # Iterate over the examples - for pred_slice, target_slice in zip( - prediction_rescaled[:n_examples], target_rescaled[:n_examples] + for pred_slice, target_slice, time_slice in zip( + prediction_rescaled[:n_examples], + target_rescaled[:n_examples], + time[:n_examples], ): # Each slice is (pred_steps, num_grid_nodes, d_f) self.plotted_examples += 1 # Increment already here + da_prediction = self._create_dataarray_from_tensor( + tensor=pred_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + da_target = self._create_dataarray_from_tensor( + tensor=target_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + var_vmin = ( torch.minimum( pred_slice.flatten(0, 1).min(dim=0)[0], @@ -453,18 +539,20 @@ def plot_examples(self, batch, n_examples, prediction=None): var_vranges = list(zip(var_vmin, var_vmax)) # Iterate over prediction horizon time steps - for t_i, (pred_t, target_t) in enumerate( - zip(pred_slice, target_slice), start=1 - ): + for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): # Create one figure per variable at this time step var_figs = [ vis.plot_prediction( - pred=pred_t[:, var_i], - target=target_t[:, var_i], datastore=self._datastore, title=f"{var_name} ({var_unit}), " f"t={t_i} ({self._datastore.step_length * t_i} h)", vrange=var_vrange, + da_prediction=da_prediction.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), + da_target=da_target.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( @@ -476,6 +564,7 @@ def plot_examples(self, batch, n_examples, prediction=None): ] example_i = self.plotted_examples + wandb.log( { f"{var_name}_example_{example_i}": wandb.Image(fig) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 6233b4d1..61c1a681 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -19,9 +19,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): super().__init__(args, config=config, datastore=datastore) # Load graph with static features - # NOTE: (IMPORTANT!) mesh nodes MUST have the first - # num_mesh_nodes indices, - graph_dir_path = datastore.root_path / "graph" / args.graph + graph_dir_path = datastore.root_path / "graphs" / args.graph_name self.hierarchical, graph_ldict = utils.load_graph( graph_dir_path=graph_dir_path ) @@ -49,6 +47,23 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): self.grid_embedder = utils.make_mlp( [self.grid_dim] + self.mlp_blueprint_end ) + + if self.boundary_forced: + # Define embedder for boundary nodes + # Optional separate embedder for boundary nodes + if args.shared_grid_embedder: + assert self.grid_dim == self.boundary_dim, ( + "Grid and boundary input dimension must " + "be the same when using " + f"the same embedder, got grid_dim={self.grid_dim}, " + f"boundary_dim={self.boundary_dim}" + ) + self.boundary_embedder = self.grid_embedder + else: + self.boundary_embedder = utils.make_mlp( + [self.boundary_dim] + self.mlp_blueprint_end + ) + self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) @@ -103,12 +118,15 @@ def process_step(self, mesh_rep): """ raise NotImplementedError("process_step not implemented") - def predict_step(self, prev_state, prev_prev_state, forcing): + def predict_step( + self, prev_state, prev_prev_state, forcing, boundary_forcing + ): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes, forcing_dim) + boundary_forcing: (B, num_boundary_nodes, boundary_forcing_dim) """ batch_size = prev_state.shape[0] @@ -123,21 +141,46 @@ def predict_step(self, prev_state, prev_prev_state, forcing): dim=-1, ) + if self.boundary_forced: + # Create full boundary node features of shape + # (B, num_boundary_nodes, boundary_dim) + boundary_features = torch.cat( + ( + boundary_forcing, + self.expand_to_batch( + self.boundary_static_features, batch_size + ), + ), + dim=-1, + ) + + # Embed boundary features + boundary_emb = self.boundary_embedder(boundary_features) + # (B, num_boundary_nodes, d_h) + # Embed all features grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h) g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h) m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h) mesh_emb = self.embedd_mesh_nodes() + if self.boundary_forced: + # Merge interior and boundary emb into input embedding + # We enforce ordering (interior, boundary) of nodes + input_emb = torch.cat((grid_emb, boundary_emb), dim=1) + else: + # Only maps from interior to mesh + input_emb = grid_emb + # Map from grid to mesh mesh_emb_expanded = self.expand_to_batch( mesh_emb, batch_size ) # (B, num_mesh_nodes, d_h) g2m_emb_expanded = self.expand_to_batch(g2m_emb, batch_size) - # This also splits representation into grid and mesh + # Encode to mesh mesh_rep = self.g2m_gnn( - grid_emb, mesh_emb_expanded, g2m_emb_expanded + input_emb, mesh_emb_expanded, g2m_emb_expanded ) # (B, num_mesh_nodes, d_h) # Also MLP with residual for grid representation grid_rep = grid_emb + self.encoding_grid_mlp( diff --git a/neural_lam/plot_graph.py b/neural_lam/plot_graph.py index 999c8e53..11bd795a 100644 --- a/neural_lam/plot_graph.py +++ b/neural_lam/plot_graph.py @@ -11,25 +11,20 @@ from . import utils from .config import load_config_and_datastore -MESH_HEIGHT = 0.1 -MESH_LEVEL_DIST = 0.2 -GRID_HEIGHT = 0 - def main(): """Plot graph structure in 3D using plotly.""" parser = ArgumentParser(description="Plot graph") parser.add_argument( - "--datastore_config_path", + "--config_path", type=str, - default="tests/datastore_examples/mdp/config.yaml", - help="Path for the datastore config", + help="Path to the configuration for neural-lam", ) parser.add_argument( - "--graph", + "--graph_name", type=str, default="multiscale", - help="Graph to plot (default: multiscale)", + help="Name of saved graph to plot (default: multiscale)", ) parser.add_argument( "--save", @@ -43,16 +38,17 @@ def main(): ) args = parser.parse_args() - _, datastore = load_config_and_datastore( - config_path=args.datastore_config_path - ) - xy = datastore.get_xy("state", stacked=True) # (N_grid, 2) - pos_max = np.max(np.abs(xy)) - grid_pos = xy / pos_max # Divide by maximum coordinate + assert ( + args.config_path is not None + ), "Specify your config with --config_path" + + _, datastore = load_config_and_datastore(config_path=args.config_path) # Load graph data - graph_dir_path = os.path.join(datastore.root_path, "graph", args.graph) + graph_dir_path = os.path.join( + datastore.root_path, "graphs", args.graph_name + ) hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path) (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = ( graph_ldict["g2m_edge_index"], @@ -65,17 +61,23 @@ def main(): ) mesh_static_features = graph_ldict["mesh_static_features"] - # Add in z-dimension - z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],)) + # Extract values needed, turn to numpy + grid_pos = utils.get_reordered_grid_pos(datastore).numpy() + grid_scale = np.ptp(grid_pos) + + # Add in z-dimension for grid + z_grid = np.zeros((grid_pos.shape[0],)) # Grid sits at z=0 grid_pos = np.concatenate( (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1 ) - # List of edges to plot, (edge_index, color, line_width, label) - edge_plot_list = [ - (m2g_edge_index.numpy(), "black", 0.4, "M2G"), - (g2m_edge_index.numpy(), "black", 0.4, "G2M"), - ] + # Compute z-coordinate height of mesh nodes + mesh_base_height = 0.05 * grid_scale + mesh_level_height_diff = 0.1 * grid_scale + + # List of edges to plot, (edge_index, from_pos, to_pos, color, + # line_width, label) + edge_plot_list = [] # Mesh positioning and edges to plot differ if we have a hierarchical graph if hierarchical: @@ -83,8 +85,8 @@ def main(): np.concatenate( ( level_static_features.numpy(), - MESH_HEIGHT - + MESH_LEVEL_DIST + mesh_base_height + + mesh_level_height_diff * height_level * np.ones((level_static_features.shape[0], 1)), ), @@ -94,52 +96,118 @@ def main(): mesh_static_features, start=1 ) ] - mesh_pos = np.concatenate(mesh_level_pos, axis=0) + all_mesh_pos = np.concatenate(mesh_level_pos, axis=0) + grid_con_mesh_pos = mesh_level_pos[0] # Add inter-level mesh edges edge_plot_list += [ - (level_ei.numpy(), "blue", 1, f"M2M Level {level}") - for level, level_ei in enumerate(m2m_edge_index) + ( + level_ei.numpy(), + level_pos, + level_pos, + "blue", + 1, + f"M2M Level {level}", + ) + for level, (level_ei, level_pos) in enumerate( + zip(m2m_edge_index, mesh_level_pos) + ) ] # Add intra-level mesh edges - up_edges_ei = np.concatenate( - [level_up_ei.numpy() for level_up_ei in mesh_up_edge_index], axis=1 + up_edges_ei = [ + level_up_ei.numpy() for level_up_ei in mesh_up_edge_index + ] + down_edges_ei = [ + level_down_ei.numpy() for level_down_ei in mesh_down_edge_index + ] + # Add up edges + for level_i, (up_ei, from_pos, to_pos) in enumerate( + zip(up_edges_ei, mesh_level_pos[:-1], mesh_level_pos[1:]) + ): + edge_plot_list.append( + ( + up_ei, + from_pos, + to_pos, + "green", + 1, + f"Mesh up {level_i}-{level_i+1}", + ) + ) + # Add down edges + for level_i, (down_ei, from_pos, to_pos) in enumerate( + zip(down_edges_ei, mesh_level_pos[1:], mesh_level_pos[:-1]) + ): + edge_plot_list.append( + ( + down_ei, + from_pos, + to_pos, + "green", + 1, + f"Mesh down {level_i+1}-{level_i}", + ) + ) + + edge_plot_list.append( + ( + m2g_edge_index.numpy(), + grid_con_mesh_pos, + grid_pos, + "black", + 0.4, + "M2G", + ) ) - down_edges_ei = np.concatenate( - [level_down_ei.numpy() for level_down_ei in mesh_down_edge_index], - axis=1, + edge_plot_list.append( + ( + g2m_edge_index.numpy(), + grid_pos, + grid_con_mesh_pos, + "black", + 0.4, + "G2M", + ) ) - edge_plot_list.append((up_edges_ei, "green", 1, "Mesh up")) - edge_plot_list.append((down_edges_ei, "green", 1, "Mesh down")) mesh_node_size = 2.5 else: mesh_pos = mesh_static_features.numpy() mesh_degrees = pyg.utils.degree(m2m_edge_index[1]).numpy() - z_mesh = MESH_HEIGHT + 0.01 * mesh_degrees + # 1% higher per neighbor + z_mesh = (1 + 0.01 * mesh_degrees) * mesh_base_height mesh_node_size = mesh_degrees / 2 mesh_pos = np.concatenate( (mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1 ) - edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M")) + edge_plot_list.append( + (m2m_edge_index.numpy(), mesh_pos, mesh_pos, "blue", 1, "M2M") + ) + edge_plot_list.append( + (m2g_edge_index.numpy(), mesh_pos, grid_pos, "black", 0.4, "M2G") + ) + edge_plot_list.append( + (g2m_edge_index.numpy(), grid_pos, mesh_pos, "black", 0.4, "G2M") + ) - # All node positions in one array - node_pos = np.concatenate((mesh_pos, grid_pos), axis=0) + all_mesh_pos = mesh_pos # Add edges data_objs = [] for ( ei, + from_pos, + to_pos, col, width, label, ) in edge_plot_list: - edge_start = node_pos[ei[0]] # (M, 2) - edge_end = node_pos[ei[1]] # (M, 2) + edge_start = from_pos[ei[0]] # (M, 2) + edge_end = to_pos[ei[1]] # (M, 2) n_edges = edge_start.shape[0] x_edges = np.stack( @@ -176,9 +244,9 @@ def main(): ) data_objs.append( go.Scatter3d( - x=mesh_pos[:, 0], - y=mesh_pos[:, 1], - z=mesh_pos[:, 2], + x=all_mesh_pos[:, 0], + y=all_mesh_pos[:, 1], + z=all_mesh_pos[:, 2], mode="markers", marker={"color": "blue", "size": mesh_node_size}, name="Mesh nodes", diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 74146c89..3c2dbece 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -78,7 +78,7 @@ def main(input_args=None): # Model architecture parser.add_argument( - "--graph", + "--graph_name", type=str, default="multiscale", help="Graph to load and use in graph-based model " @@ -116,6 +116,13 @@ def main(input_args=None): "output dimensions " "(default: False (no))", ) + parser.add_argument( + "--shared_grid_embedder", + action="store_true", # Default to separate embedders + help="If the same embedder MLP should be used for interior and boundary" + " grid nodes. Note that this requires the same dimensionality for " + "both kinds of grid inputs. (default: False (no))", + ) # Training options parser.add_argument( @@ -203,6 +210,18 @@ def main(input_args=None): default=1, help="Number of future time steps to use as input for forcing data", ) + parser.add_argument( + "--num_past_boundary_steps", + type=int, + default=1, + help="Number of past time steps to use as input for boundary data", + ) + parser.add_argument( + "--num_future_boundary_steps", + type=int, + default=1, + help="Number of future time steps to use as input for boundary data", + ) args = parser.parse_args(input_args) args.var_leads_metrics_watch = { int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() @@ -226,16 +245,21 @@ def main(input_args=None): seed.seed_everything(args.seed) # Load neural-lam configuration and datastore to use - config, datastore = load_config_and_datastore(config_path=args.config_path) + config, datastore, datastore_boundary = load_config_and_datastore( + config_path=args.config_path + ) # Create datamodule data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=args.ar_steps_train, ar_steps_eval=args.ar_steps_eval, standardize=True, num_past_forcing_steps=args.num_past_forcing_steps, num_future_forcing_steps=args.num_future_forcing_steps, + num_past_boundary_steps=args.num_past_boundary_steps, + num_future_boundary_steps=args.num_future_boundary_steps, batch_size=args.batch_size, num_workers=args.num_workers, ) diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 4a0752e4..6f910cee 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -33,6 +33,13 @@ def __iter__(self): return (self[i] for i in range(len(self))) +def zero_index_edge_index(edge_index): + """ + Make both sender and receiver indices of edge_index start at 0 + """ + return edge_index - edge_index.min(dim=1, keepdim=True)[0] + + def load_graph(graph_dir_path, device="cpu"): """Load all tensors representing the graph from `graph_dir_path`. @@ -71,11 +78,13 @@ def load_graph(graph_dir_path, device="cpu"): - mesh_down_edge_index - g2m_features - m2g_features - - m2m_features + - m2m_node_features - mesh_up_features - mesh_down_features - mesh_static_features + + Load all tensors representing the graph """ def loads_file(fn): @@ -85,13 +94,49 @@ def loads_file(fn): weights_only=True, ) + # Load static node features + mesh_static_features = loads_file( + "m2m_node_features.pt" + ) # List of (N_mesh[l], d_mesh_static) + # Load edges (edge_index) m2m_edge_index = BufferList( - loads_file("m2m_edge_index.pt"), persistent=False + [zero_index_edge_index(ei) for ei in loads_file("m2m_edge_index.pt")], + persistent=False, ) # List of (2, M_m2m[l]) g2m_edge_index = loads_file("g2m_edge_index.pt") # (2, M_g2m) m2g_edge_index = loads_file("m2g_edge_index.pt") # (2, M_m2g) + # Change first indices to 0 + g2m_edge_index = zero_index_edge_index(g2m_edge_index) + # m2g has to be handled specially as not all mesh nodes might be indexed in + # m2g_edge_index + m2g_min_indices = m2g_edge_index.min(dim=1, keepdim=True)[0] + if m2g_min_indices[0] < m2g_min_indices[1]: + # mesh has the first indices + # Number of mesh nodes at level that connects to grid + num_mesh_nodes = mesh_static_features[0].shape[0] + + m2g_edge_index = torch.stack( + ( + m2g_edge_index[0], + m2g_edge_index[1] - num_mesh_nodes, + ), + dim=0, + ) + else: + # grid (interior) has the first indices + # NOTE: Below works, but would be good with a better way to get this + num_interior_nodes = m2g_edge_index[1].max() + 1 + + m2g_edge_index = torch.stack( + ( + m2g_edge_index[0] - num_interior_nodes, + m2g_edge_index[1], + ), + dim=0, + ) + n_levels = len(m2m_edge_index) hierarchical = n_levels > 1 # Nor just single level mesh graph @@ -112,11 +157,6 @@ def loads_file(fn): g2m_features = g2m_features / longest_edge m2g_features = m2g_features / longest_edge - # Load static node features - mesh_static_features = loads_file( - "mesh_features.pt" - ) # List of (N_mesh[l], d_mesh_static) - # Some checks for consistency assert ( len(m2m_features) == n_levels @@ -128,10 +168,18 @@ def loads_file(fn): if hierarchical: # Load up and down edges and features mesh_up_edge_index = BufferList( - loads_file("mesh_up_edge_index.pt"), persistent=False + [ + zero_index_edge_index(ei) + for ei in loads_file("mesh_up_edge_index.pt") + ], + persistent=False, ) # List of (2, M_up[l]) mesh_down_edge_index = BufferList( - loads_file("mesh_down_edge_index.pt"), persistent=False + [ + zero_index_edge_index(ei) + for ei in loads_file("mesh_down_edge_index.pt") + ], + persistent=False, ) # List of (2, M_down[l]) mesh_up_features = loads_file( @@ -241,3 +289,24 @@ def init_wandb_metrics(wandb_logger, val_steps): experiment.define_metric("val_mean_loss", summary="min") for step in val_steps: experiment.define_metric(f"val_loss_unroll{step}", summary="min") + + +def get_reordered_grid_pos(datastore): + """ + Interior nodes first, then boundary + """ + xy_np = datastore.get_xy("state") # np, (num_grid, 2) + xy_torch = torch.tensor(xy_np, dtype=torch.float32) + + da_boundary_mask = datastore.boundary_mask + boundary_mask = torch.tensor(da_boundary_mask.values, dtype=torch.bool) + interior_mask = torch.logical_not(boundary_mask) + + return torch.cat( + ( + xy_torch[interior_mask], + xy_torch[boundary_mask], + ), + dim=0, + ) + # (num_total_grid_nodes, 2) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index b9d18b39..10b84fb7 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -2,6 +2,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import xarray as xr # Local from . import utils @@ -63,11 +64,47 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): return fig +def plot_on_axis( + ax, + da, + datastore, + obs_mask=None, + vmin=None, + vmax=None, + ax_title=None, + cmap="plasma", + grid_limits=None, +): + """ + Plot weather state on given axis + """ + ax.set_global() + ax.coastlines() # Add coastline outlines + + extent = datastore.get_xy_extent("state") + + im = da.plot.imshow( + ax=ax, + origin="lower", + x="x", + extent=extent, + vmin=vmin, + vmax=vmax, + cmap=cmap, + transform=datastore.coords_projection, + ) + + if ax_title: + ax.set_title(ax_title, size=15) + + return im + + @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( - pred, - target, datastore: BaseRegularGridDatastore, + da_prediction: xr.DataArray = None, + da_target: xr.DataArray = None, title=None, vrange=None, ): @@ -79,20 +116,11 @@ def plot_prediction( """ # Get common scale for values if vrange is None: - vmin = min(vals.min().cpu().item() for vals in (pred, target)) - vmax = max(vals.max().cpu().item() for vals in (pred, target)) + vmin = min(da_prediction.min(), da_target.min()) + vmax = max(da_prediction.max(), da_target.max()) else: vmin, vmax = vrange - extent = datastore.get_xy_extent("state") - - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region - fig, axes = plt.subplots( 1, 2, @@ -101,28 +129,18 @@ def plot_prediction( ) # Plot pred and target - for ax, data in zip(axes, (target, pred)): - ax.coastlines() # Add coastline outlines - data_grid = ( - data.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() - .numpy() - ) - im = ax.imshow( - data_grid, - origin="lower", - extent=extent, - alpha=pixel_alpha, + for ax, da in zip(axes, (da_target, da_prediction)): + plot_on_axis( + ax, + da, + datastore, vmin=vmin, vmax=vmax, - cmap="plasma", ) # Ticks and labels axes[0].set_title("Ground Truth", size=15) axes[1].set_title("Prediction", size=15) - cbar = fig.colorbar(im, aspect=30) - cbar.ax.tick_params(labelsize=10) if title: fig.suptitle(title, size=20) @@ -145,32 +163,25 @@ def plot_spatial_error( else: vmin, vmax = vrange - extent = datastore.get_xy_extent("state") - - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region - fig, ax = plt.subplots( figsize=(5, 4.8), subplot_kw={"projection": datastore.coords_projection}, ) - ax.coastlines() # Add coastline outlines error_grid = ( - error.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() + error.reshape( + [datastore.grid_shape_state.x, datastore.grid_shape_state.y] + ) + .T.cpu() .numpy() ) + extent = datastore.get_xy_extent("state") + # TODO: This needs to be converted to DA and use plot_on_axis im = ax.imshow( error_grid, origin="lower", extent=extent, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="OrRd", diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 532e3c90..f02cfbd4 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -22,6 +22,8 @@ class WeatherDataset(torch.utils.data.Dataset): ---------- datastore : BaseDatastore The datastore to load the data from (e.g. mdp). + datastore_boundary : BaseDatastore + The boundary datastore to load the data from (e.g. mdp). split : str, optional The data split to use ("train", "val" or "test"). Default is "train". ar_steps : int, optional @@ -36,6 +38,16 @@ class WeatherDataset(torch.utils.data.Dataset): forcing from times t, t+1, ..., t+j-1, t+j (and potentially times before t, given num_past_forcing_steps) are included as forcing inputs at time t. Default is 1. + num_past_boundary_steps: int, optional + Number of past time steps to include in boundary input. If set to i, + boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, + given num_future_forcing_steps) are included as boundary inputs at time + t Default is 1. + num_future_boundary_steps: int, optional + Number of future time steps to include in boundary input. If set to j, + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times + before t, given num_past_forcing_steps) are included as boundary inputs + at time t. Default is 1. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -43,10 +55,13 @@ class WeatherDataset(torch.utils.data.Dataset): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, split="train", ar_steps=3, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, standardize=True, ): super().__init__() @@ -54,15 +69,31 @@ def __init__( self.split = split self.ar_steps = ar_steps self.datastore = datastore + self.datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps self.da_state = self.datastore.get_dataarray( category="state", split=self.split ) + if self.da_state is None: + raise ValueError( + "A non-empty state dataarray must be provided. " + "The datastore.get_dataarray() returned None or empty array " + "for category='state'" + ) self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) + # XXX For now boundary data is always considered mdp-forcing data + if self.datastore_boundary is not None: + self.da_boundary = self.datastore_boundary.get_dataarray( + category="forcing", split=self.split + ) + else: + self.da_boundary = None # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -94,6 +125,85 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + # handling ensemble data + if self.datastore.is_ensemble: + # for the now the strategy is to only include the first ensemble + # member + # XXX: this could be changed to include all ensemble members by + # splitting `idx` into two parts, one for the analysis time and one + # for the ensemble member and then increasing self.__len__ to + # include all ensemble members + warnings.warn( + "only use of ensemble member 0 (the first member) is " + "implemented for ensemble data" + ) + i_ensemble = 0 + self.da_state = self.da_state.isel(ensemble_member=i_ensemble) + else: + self.da_state = self.da_state + + # Check time step consistency in state data + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time + _ = self._get_time_step(state_times) + + # Check time coverage for forcing and boundary data + if self.da_forcing is not None or self.da_boundary is not None: + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time + state_time_min = state_times.min().values + state_time_max = state_times.max().values + + if self.da_forcing is not None: + # Forcing data is part of the same datastore as state data + # During creation the time dimension of the forcing data + # is matched to the state data + if self.datastore.is_forecast: + forcing_times = self.da_forcing.analysis_time + else: + forcing_times = self.da_forcing.time + self._get_time_step(forcing_times.values) + + if self.da_boundary is not None: + # Boundary data is part of a separate datastore + # The boundary data is allowed to have a different time_step + # Check that the boundary data covers the required time range + if self.datastore_boundary.is_forecast: + boundary_times = self.da_boundary.analysis_time + else: + boundary_times = self.da_boundary.time + boundary_time_step = self._get_time_step(boundary_times.values) + boundary_time_min = boundary_times.min().values + boundary_time_max = boundary_times.max().values + + # Calculate required bounds for boundary using its time step + boundary_required_time_min = ( + state_time_min + - self.num_past_forcing_steps * boundary_time_step + ) + boundary_required_time_max = ( + state_time_max + + self.num_future_forcing_steps * boundary_time_step + ) + + if boundary_time_min > boundary_required_time_min: + raise ValueError( + f"Boundary data starts too late." + f"Required start: {boundary_required_time_min}, " + f"but boundary starts at {boundary_time_min}." + ) + + if boundary_time_max < boundary_required_time_max: + raise ValueError( + f"Boundary data ends too early." + f"Required end: {boundary_required_time_max}, " + f"but boundary ends at {boundary_time_max}." + ) + # Set up for standardization # TODO: This will become part of ar_model.py soon! self.standardize = standardize @@ -114,6 +224,16 @@ def __init__( self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std + # XXX: Again, the boundary data is considered forcing data for now + if self.da_boundary is not None: + self.ds_boundary_stats = ( + self.datastore_boundary.get_standardization_dataarray( + category="forcing" + ) + ) + self.da_boundary_mean = self.ds_boundary_stats.forcing_mean + self.da_boundary_std = self.ds_boundary_stats.forcing_std + def __len__(self): if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time @@ -126,7 +246,7 @@ def __len__(self): warnings.warn( "only using first ensemble member, so dataset size is " " effectively reduced by the number of ensemble members " - f"({self.da_state.ensemble_member.size})", + f"({self.datastore._num_ensemble_members})", UserWarning, ) @@ -160,32 +280,67 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_state_time(self, da_state, idx, n_steps: int): + def _get_time_step(self, times): + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. """ - Produce a time slice of the given dataarray `da_state` (state) starting - at `idx` and with `n_steps` steps. An `offset`is calculated based on the - `num_past_forcing_steps` class attribute. `Offset` is used to offset the - start of the sample, to assert that enough previous time steps are - available for the 2 initial states and any corresponding forcings - (calculated in `_slice_forcing_time`). + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + def _slice_time( + self, + da_state, + idx, + n_steps: int, + da_forcing_boundary=None, + num_past_steps=None, + num_future_steps=None, + ): + """ + Produce time slices of the given dataarrays `da_state` (state) and + `da_forcing_boundary`. For the state data, slicing is done + based on `idx`. For the forcing/boundary data, nearest neighbor matching + is performed based on the state times. Additionally, the time difference + between the matched forcing/boundary times and state times (in multiples + of state time steps) is added to the forcing dataarray. This will be + used as an additional input feature in the model (temporal embedding). Parameters ---------- da_state : xr.DataArray - The dataarray to slice. This is expected to have a `time` dimension - if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. + The state dataarray to slice. idx : int - The index of the time step to start the sample from. + The index of the time step to start the sample from in the state + data. n_steps : int The number of time steps to include in the sample. + da_forcing_boundary : xr.DataArray + The forcing/boundary dataarray to slice. + num_past_steps : int, optional + The number of past time steps to include in the forcing/boundary + data. Default is `None`. + num_future_steps : int, optional + The number of future time steps to include in the forcing/boundary + data. Default is `None`. Returns ------- - da_sliced : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', + da_state_sliced : xr.DataArray + The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). + da_forcing_boundary_matched : xr.DataArray + The sliced state dataarray with dims ('time', 'grid_index', + 'forcing/boundary_feature_windowed'). + If no forcing/boundary data is provided, this will be `None`. """ # The current implementation requires at least 2 time steps for the # initial state (see GraphCast). @@ -199,82 +354,64 @@ def _slice_state_time(self, da_state, idx, n_steps: int): # simply select a analysis time and the first `n_steps` forecast # times (given no offset). Note that this means that we get one # sample per forecast, always starting at forecast time 2. - da_sliced = da_state.isel( + da_state_sliced = da_state.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx), ) # create a new time dimension so that the produced sample has a # `time` dimension, similarly to the analysis only data - da_sliced["time"] = ( - da_sliced.analysis_time + da_sliced.elapsed_forecast_duration + da_state_sliced["time"] = ( + da_state_sliced.analysis_time + + da_state_sliced.elapsed_forecast_duration ) - da_sliced = da_sliced.swap_dims( + da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) + # Asserting that the forecast time step is consistent + self._get_time_step(da_state_sliced.time) + else: # For analysis data we slice the time dimension directly. The offset # is only relevant for the very first (and last) samples in the # dataset. - start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) - end_idx = ( - idx + max(init_steps, self.num_past_forcing_steps) + n_steps - ) - da_sliced = da_state.isel(time=slice(start_idx, end_idx)) - return da_sliced - - def _slice_forcing_time(self, da_forcing, idx, n_steps: int): - """ - Produce a time slice of the given dataarray `da_forcing` (forcing) - starting at `idx` and with `n_steps` steps. An `offset` is calculated - based on the `num_past_forcing_steps` class attribute. It is used to - offset the start of the sample, to ensure that enough previous time - steps are available for the forcing data. The forcing data is windowed - around the current autoregressive time step to include the past and - future forcings. - - Parameters - ---------- - da_forcing : xr.DataArray - The forcing dataarray to slice. This is expected to have a `time` - dimension if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. - idx : int - The index of the time step to start the sample from. - n_steps : int - The number of time steps to include in the sample. - - Returns - ------- - da_concat : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', - 'window', 'forcing_feature'). - """ - # The current implementation requires at least 2 time steps for the - # initial state (see GraphCast). The forcing data is windowed around the - # current autregressive time step. The two `init_steps` can also be used - # as past forcings. - init_steps = 2 - da_list = [] - - if self.datastore.is_forecast: - # This implies that the data will have both `analysis_time` and - # `elapsed_forecast_duration` dimensions for forecasts. We for now - # simply select an analysis time and the first `n_steps` forecast - # times (given no offset). Note that this means that we get one - # sample per forecast. + start_idx = idx + max(0, num_past_steps - init_steps) + end_idx = idx + max(init_steps, num_past_steps) + n_steps + da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) + + if da_forcing_boundary is None: + return da_state_sliced, None + + # Get the state times and its temporal resolution for matching with + # forcing data + state_times = da_state_sliced["time"] + state_time_step = state_times.values[1] - state_times.values[0] + + # Here we cannot check 'self.datastore.is_forecast' directly because we + # might be dealing with a datastore_boundary + if "analysis_time" in da_forcing_boundary.dims: + # Select the closest analysis time in the forcing/boundary data + # This is mostly relevant for boundary data where the time steps + # are not necessarily the same as the state data. But still fast + # enough for forcing data where the time steps are the same. + idx = np.abs( + da_forcing_boundary.analysis_time.values + - self.da_state.analysis_time.values[idx] + ).argmin() # Add a 'time' dimension using the actual forecast times - offset = max(init_steps, self.num_past_forcing_steps) + offset = max(init_steps, num_past_steps) + da_list = [] for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps + start_idx = offset + step - num_past_steps + end_idx = offset + step + num_future_steps current_time = ( - da_forcing.analysis_time[idx] - + da_forcing.elapsed_forecast_duration[offset + step] + da_forcing_boundary.analysis_time[idx] + + da_forcing_boundary.elapsed_forecast_duration[ + offset + step + ] ) - da_sliced = da_forcing.isel( + da_sliced = da_forcing_boundary.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx + 1), ) @@ -285,7 +422,7 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int): # Assign the 'window' coordinate to be relative positions da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) + window=np.arange(-num_past_steps, num_future_steps + 1) ) da_sliced = da_sliced.expand_dims( @@ -294,46 +431,115 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int): da_list.append(da_sliced) - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + # Generate temporal embedding `time_diff_steps` for the + # forcing/boundary data. This is the time difference in multiples + # of state time steps between the forcing/boundary time and the + # state time. + da_forcing_boundary_matched = xr.concat(da_list, dim="time") + forcing_time_step = ( + da_forcing_boundary_matched.time.values[1] + - da_forcing_boundary_matched.time.values[0] + ) + da_forcing_boundary_matched["window"] = da_forcing_boundary_matched[ + "window" + ] * (forcing_time_step / state_time_step) + time_diff_steps = da_forcing_boundary_matched.isel( + grid_index=0, forcing_feature=0 + ).data else: # For analysis data, we slice the time dimension directly. The # offset is only relevant for the very first (and last) samples in # the dataset. - offset = idx + max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps + forcing_times = da_forcing_boundary["time"] + + # Compute time differences between forcing and state times + # (in multiples of state time steps) + # Retrieve the indices of the closest times in the forcing data + time_deltas = ( + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) / state_time_step + idx_min = np.abs(time_deltas).argmin(axis=0) + + time_diff_steps = np.stack( + [ + time_deltas[ + idx_i - num_past_steps : idx_i + num_future_steps + 1, + init_steps + step_i, + ] + for (step_i, idx_i) in enumerate(idx_min[init_steps:]) + ], + ) - # Slice the data over the desired time window - da_sliced = da_forcing.isel(time=slice(start_idx, end_idx + 1)) + # Create window dimension for forcing data to stack later + window_size = num_past_steps + num_future_steps + 1 + da_forcing_boundary_windowed = da_forcing_boundary.rolling( + time=window_size, center=False + ).construct(window_dim="window") + da_forcing_boundary_matched = da_forcing_boundary_windowed.isel( + time=idx_min[init_steps:] + ) - da_sliced = da_sliced.rename({"time": "window"}) + # Add time difference as a new coordinate to concatenate to the + # forcing features later as temporal embedding + da_forcing_boundary_matched["time_diff_steps"] = ( + ("time", "window"), + time_diff_steps, + ) - # Assign the 'window' coordinate to be relative positions - da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) - ) + return da_state_sliced, da_forcing_boundary_matched - # Add a 'time' dimension to keep track of steps using actual - # time coordinates - current_time = da_forcing.time[offset + step] - da_sliced = da_sliced.expand_dims( - dim={"time": [current_time.values]} - ) + def _process_windowed_data(self, da_windowed, da_state, da_target_times): + """Helper function to process windowed data. This function stacks the + 'forcing_feature' and 'window' dimensions and adds the time step + differences to the existing features as a temporal embedding. - da_list.append(da_sliced) + Parameters + ---------- + da_windowed : xr.DataArray + The windowed data to process. Can be `None` if no data is provided. + da_state : xr.DataArray + The state dataarray. + da_target_times : xr.DataArray + The target times. - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + Returns + ------- + da_windowed : xr.DataArray + The processed windowed data. If `da_windowed` is `None`, an empty + DataArray with the correct dimensions and coordinates is returned. - return da_concat + """ + stacked_dim = "forcing_feature_windowed" + if da_windowed is not None: + # Stack the 'feature' and 'window' dimensions and add the + # time step differences to the existing features as a temporal + # embedding + da_windowed = da_windowed.stack( + {stacked_dim: ("forcing_feature", "window")} + ) + da_windowed = xr.concat( + [da_windowed, da_windowed.time_diff_steps], + dim="forcing_feature_windowed", + ) + else: + # Create empty DataArray with the correct dimensions and coordinates + da_windowed = xr.DataArray( + data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), + dims=("time", "grid_index", f"{stacked_dim}"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + f"{stacked_dim}": [], + }, + ) + return da_windowed def _build_item_dataarrays(self, idx): """ - Create the dataarrays for the initial states, target states and forcing - data for the sample at index `idx`. + Create the dataarrays for the initial states, target states, forcing + and boundary data for the sample at index `idx`. Parameters ---------- @@ -348,26 +554,13 @@ def _build_item_dataarrays(self, idx): The dataarray for the target states. da_forcing_windowed : xr.DataArray The dataarray for the forcing data, windowed for the sample. + da_boundary_windowed : xr.DataArray + The dataarray for the boundary data, windowed for the sample. + Boundary data is always considered forcing data. da_target_times : xr.DataArray The dataarray for the target times. """ - # handling ensemble data - if self.datastore.is_ensemble: - # for the now the strategy is to only include the first ensemble - # member - # XXX: this could be changed to include all ensemble members by - # splitting `idx` into two parts, one for the analysis time and one - # for the ensemble member and then increasing self.__len__ to - # include all ensemble members - warnings.warn( - "only use of ensemble member 0 (the first member) is " - "implemented for ensemble data" - ) - i_ensemble = 0 - da_state = self.da_state.isel(ensemble_member=i_ensemble) - else: - da_state = self.da_state - + da_state = self.da_state if self.da_forcing is not None: if "ensemble_member" in self.da_forcing.dims: raise NotImplementedError( @@ -377,20 +570,42 @@ def _build_item_dataarrays(self, idx): else: da_forcing = None - # handle time sampling in a way that is compatible with both analysis - # and forecast data - da_state = self._slice_state_time( - da_state=da_state, idx=idx, n_steps=self.ar_steps - ) - if da_forcing is not None: - da_forcing_windowed = self._slice_forcing_time( - da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps + if self.da_boundary is not None: + da_boundary = self.da_boundary + else: + da_boundary = None + + # if da_forcing_boundary is None, the function will return None for + # da_forcing_windowed + if da_boundary is not None: + _, da_boundary_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing_boundary=da_boundary, + num_future_steps=self.num_future_boundary_steps, + num_past_steps=self.num_past_boundary_steps, ) + else: + da_boundary_windowed = None + # XXX: Currently, the order of the `slice_time` calls is important + # as `da_state` is modified in the second call. This should be + # refactored to be more robust. + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing_boundary=da_forcing, + num_future_steps=self.num_future_forcing_steps, + num_past_steps=self.num_past_forcing_steps, + ) # load the data into memory da_state.load() if da_forcing is not None: da_forcing_windowed.load() + if da_boundary is not None: + da_boundary_windowed.load() da_init_states = da_state.isel(time=slice(0, 2)) da_target_states = da_state.isel(time=slice(2, None)) @@ -413,30 +628,27 @@ def _build_item_dataarrays(self, idx): da_forcing_windowed - self.da_forcing_mean ) / self.da_forcing_std - if da_forcing is not None: - # stack the `forcing_feature` and `window_sample` dimensions into a - # single `forcing_feature` dimension - da_forcing_windowed = da_forcing_windowed.stack( - forcing_feature_windowed=("forcing_feature", "window") - ) - else: - # create an empty forcing tensor with the right shape - da_forcing_windowed = xr.DataArray( - data=np.empty( - (self.ar_steps, da_state.grid_index.size, 0), - ), - dims=("time", "grid_index", "forcing_feature"), - coords={ - "time": da_target_times, - "grid_index": da_state.grid_index, - "forcing_feature": [], - }, - ) + if da_boundary is not None: + da_boundary_windowed = ( + da_boundary_windowed - self.da_boundary_mean + ) / self.da_boundary_std + + # This function handles the stacking of the forcing and boundary data + # and adds the time step differences as a temporal embedding. + # It can handle `None` inputs for the forcing and boundary data + # (and simlpy return an empty DataArray in that case). + da_forcing_windowed = self._process_windowed_data( + da_forcing_windowed, da_state, da_target_times + ) + da_boundary_windowed = self._process_windowed_data( + da_boundary_windowed, da_state, da_target_times + ) return ( da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) @@ -471,6 +683,7 @@ def __getitem__(self, idx): da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) = self._build_item_dataarrays(idx=idx) @@ -487,13 +700,20 @@ def __getitem__(self, idx): ) forcing = torch.tensor(da_forcing_windowed.values, dtype=tensor_dtype) + boundary = torch.tensor(da_boundary_windowed.values, dtype=tensor_dtype) # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) # forcing: (ar_steps, N_grid, d_windowed_forcing) + # boundary: (ar_steps, N_grid, d_windowed_boundary) # target_times: (ar_steps,) - return init_states, target_states, forcing, target_times + # Assert that the boundary data is an empty tensor if the corresponding + # datastore_boundary is `None` + if self.datastore_boundary is None: + assert boundary.numel() == 0 + + return init_states, target_states, forcing, boundary, target_times def __iter__(self): """ @@ -529,7 +749,8 @@ def create_dataarray_from_tensor( tensor : torch.Tensor The tensor to construct the DataArray from, this assumed to have the same dimension ordering as returned by the __getitem__ method - (i.e. time, grid_index, {category}_feature). + (i.e. time, grid_index, {category}_feature). The tensor will be + copied to the CPU before constructing the DataArray. time : datetime.datetime or list[datetime.datetime] The time or times of the tensor. category : str @@ -581,7 +802,7 @@ def _is_listlike(obj): coords["time"] = time da = xr.DataArray( - tensor.numpy(), + tensor.cpu().numpy(), dims=dims, coords=coords, ) @@ -605,18 +826,24 @@ class WeatherDataModule(pl.LightningDataModule): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, ar_steps_train=3, ar_steps_eval=25, standardize=True, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, batch_size=4, num_workers=16, ): super().__init__() self._datastore = datastore + self._datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps self.ar_steps_train = ar_steps_train self.ar_steps_eval = ar_steps_eval self.standardize = standardize @@ -626,8 +853,10 @@ def __init__( self.val_dataset = None self.test_dataset = None if num_workers > 0: - # default to spawn for now, as the default on linux "fork" hangs - # when using dask (which the npyfilesmeps datastore uses) + # BUG: There also seem to be issues with "spawn" and `gloo`, to be + # investigated. Defaults to spawn for now, as the default on linux + # "fork" hangs when using dask (which the npyfilesmeps datastore + # uses) self.multiprocessing_context = "spawn" else: self.multiprocessing_context = None @@ -636,29 +865,38 @@ def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="train", ar_steps=self.ar_steps_train, standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) self.val_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="val", ar_steps=self.ar_steps_eval, standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) if stage == "test" or stage is None: self.test_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="test", ar_steps=self.ar_steps_eval, standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) def train_dataloader(self): diff --git a/pyproject.toml b/pyproject.toml index f0bc0851..9607d1da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,13 +25,14 @@ dependencies = [ "torch>=2.3.0", "torch-geometric==2.3.1", "parse>=1.20.2", - "dataclass-wizard>=0.22.3", + "dataclass-wizard<0.31.0", "mllam-data-prep>=0.5.0", + "weather-model-graphs>=0.2.0" ] requires-python = ">=3.9" [project.optional-dependencies] -dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"] +dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2", "gcsfs>=2021.10.0"] [tool.setuptools] py-modules = ["neural_lam"] diff --git a/tests/conftest.py b/tests/conftest.py index 6f579621..90a86d0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,6 +94,15 @@ def download_meps_example_reduced_dataset(): dummydata=None, ) +DATASTORES_BOUNDARY_EXAMPLES = { + "mdp": ( + DATASTORE_EXAMPLES_ROOT_PATH + / "mdp" + / "era5_1000hPa_danra_100m_winds" + / "era5.datastore.yaml" + ), +} + DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore @@ -102,5 +111,13 @@ def init_datastore_example(datastore_kind): datastore_kind=datastore_kind, config_path=DATASTORES_EXAMPLES[datastore_kind], ) - return datastore + + +def init_datastore_boundary_example(datastore_kind): + datastore_boundary = init_datastore( + datastore_kind=datastore_kind, + config_path=DATASTORES_BOUNDARY_EXAMPLES[datastore_kind], + ) + + return datastore_boundary diff --git a/tests/datastore_examples/.gitignore b/tests/datastore_examples/.gitignore index e84e6493..4fbd2326 100644 --- a/tests/datastore_examples/.gitignore +++ b/tests/datastore_examples/.gitignore @@ -1,2 +1,3 @@ npyfilesmeps/*.zip -npyfilesmeps/meps_example_reduced/ +npyfilesmeps/meps_example_reduced +npyfilesmeps/era5_1000hPa_temp_meps_example_reduced diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore new file mode 100644 index 00000000..f2828f46 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore @@ -0,0 +1,2 @@ +*.zarr/ +graph/ diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml new file mode 100644 index 00000000..a158bee3 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml @@ -0,0 +1,12 @@ +datastore: + kind: mdp + config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + u100m: 1.0 + v100m: 1.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml new file mode 100644 index 00000000..3edf1267 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml @@ -0,0 +1,99 @@ +schema_version: v0.5.0 +dataset_version: v0.1.0 + +output: + variables: + static: [grid_index, static_feature] + state: [time, grid_index, state_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-03T00:00 + end: 1990-09-09T00:00 + step: PT3H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-03T00:00 + end: 1990-09-06T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-06T00:00 + end: 1990-09-07T00:00 + test: + start: 1990-09-07T00:00 + end: 1990-09-09T00:00 + +inputs: + danra_height_levels: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr + dims: [time, x, y, altitude] + variables: + u: + altitude: + values: [100,] + units: m + v: + altitude: + values: [100, ] + units: m + dim_mapping: + time: + method: rename + dim: time + state_feature: + method: stack_variables_by_var_name + dims: [altitude] + name_format: "{var_name}{altitude}m" + grid_index: + method: stack + dims: [x, y] + target_output_variable: state + + danra_surface: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr + dims: [time, x, y] + variables: + # use surface incoming shortwave radiation as forcing + - swavr0m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: forcing + + danra_lsm: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr + dims: [x, y] + variables: + - lsm + dim_mapping: + grid_index: + method: stack + dims: [x, y] + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: static + +extra: + projection: + class_name: LambertConformal + kwargs: + central_longitude: 25.0 + central_latitude: 56.7 + standard_parallels: [56.7, 56.7] + globe: + semimajor_axis: 6367470.0 + semiminor_axis: 6367470.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml new file mode 100644 index 00000000..7c5ffb3b --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -0,0 +1,85 @@ +schema_version: v0.5.0 +dataset_version: v1.0.0 + +output: + variables: + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + test: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_surface_net_short_wave_radiation_flux + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml new file mode 100644 index 00000000..27cc9764 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml @@ -0,0 +1,18 @@ +datastore: + kind: npyfilesmeps + config_path: meps_example_reduced.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + nlwrs_0: 1.0 + nswrs_0: 1.0 + pres_0g: 1.0 + pres_0s: 1.0 + r_2: 1.0 + r_65: 1.0 + t_2: 1.0 + t_65: 1.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml new file mode 100644 index 00000000..7c5ffb3b --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml @@ -0,0 +1,85 @@ +schema_version: v0.5.0 +dataset_version: v1.0.0 + +output: + variables: + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + test: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_surface_net_short_wave_radiation_flux + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml new file mode 100644 index 00000000..3d88d4a4 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml @@ -0,0 +1,44 @@ +dataset: + name: meps_example_reduced + num_forcing_features: 16 + var_longnames: + - pres_heightAboveGround_0_instant + - pres_heightAboveSea_0_instant + - nlwrs_heightAboveGround_0_accum + - nswrs_heightAboveGround_0_accum + - r_heightAboveGround_2_instant + - r_hybrid_65_instant + - t_heightAboveGround_2_instant + - t_hybrid_65_instant + var_names: + - pres_0g + - pres_0s + - nlwrs_0 + - nswrs_0 + - r_2 + - r_65 + - t_2 + - t_65 + var_units: + - Pa + - Pa + - W/m**2 + - W/m**2 + - '' + - '' + - K + - K + num_timesteps: 65 + num_ensemble_members: 2 + step_length: 3 +grid_shape_state: +- 134 +- 119 +projection: + class_name: LambertConformal + kwargs: + central_latitude: 63.3 + central_longitude: 15.0 + standard_parallels: + - 63.3 + - 63.3 diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index 9075d404..a958b8f5 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -28,7 +28,7 @@ class DummyDatastore(BaseRegularGridDatastore): """ SHORT_NAME = "dummydata" - T0 = isodate.parse_datetime("2021-01-01T00:00:00") + T0 = isodate.parse_datetime("1990-09-02T00:00:00") N_FEATURES = dict(state=5, forcing=2, static=1) CARTESIAN_COORDS = ["x", "y"] @@ -148,12 +148,6 @@ def __init__( times = [self.T0 + dt * i for i in range(n_timesteps)] self.ds.coords["time"] = times - # Add boundary mask - self.ds["boundary_mask"] = xr.DataArray( - np.random.choice([0, 1], size=(n_points_1d, n_points_1d)), - dims=["x", "y"], - ) - # Stack the spatial dimensions into grid_index self.ds = self.ds.stack(grid_index=self.CARTESIAN_COORDS) @@ -342,22 +336,6 @@ def get_dataarray( dim_order = self.expected_dim_order(category=category) return self.ds[category].transpose(*dim_order) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - return self.ds["boundary_mask"] - def get_xy(self, category: str, stacked: bool) -> ndarray: """Return the x, y coordinates of the dataset. diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 419aece0..063ec147 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -14,12 +14,19 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataset -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) from tests.dummy_datastore import DummyDatastore @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item_shapes(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_shapes(datastore_name, datastore_boundary_name): """Check that the `datastore.get_dataarray` method is implemented. Validate the shapes of the tensors match between the different @@ -31,24 +38,33 @@ def test_dataset_item_shapes(datastore_name): """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) N_gridpoints = datastore.num_grid_points + N_gridpoints_boundary = datastore_boundary.num_grid_points N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 dataset = WeatherDataset( datastore=datastore, + datastore_boundary=datastore_boundary, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, ) item = dataset[0] # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - init_states, target_states, forcing, target_times = item + init_states, target_states, forcing, boundary, target_times = item # initial states assert init_states.ndim == 3 @@ -66,8 +82,23 @@ def test_dataset_item_shapes(datastore_name): assert forcing.ndim == 3 assert forcing.shape[0] == N_pred_steps assert forcing.shape[1] == N_gridpoints - assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * ( - num_past_forcing_steps + num_future_forcing_steps + 1 + # each stacked forcing feature has one corresponding temporal embedding + assert ( + forcing.shape[2] + == datastore.get_num_data_vars("forcing") + * (num_past_forcing_steps + num_future_forcing_steps + 1) + * 2 + ) + + # boundary + assert boundary.ndim == 3 + assert boundary.shape[0] == N_pred_steps + assert boundary.shape[1] == N_gridpoints_boundary + assert ( + boundary.shape[2] + == datastore_boundary.get_num_data_vars("forcing") + * (num_past_boundary_steps + num_future_boundary_steps + 1) + * 2 ) # batch times @@ -87,8 +118,10 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name): N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -99,10 +132,14 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name): # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - _, target_states, _, target_times_arr = dataset[idx] - _, da_target_true, _, da_target_times_true = dataset._build_item_dataarrays( - idx=idx - ) + _, target_states, _, _, target_times_arr = dataset[idx] + ( + _, + da_target_true, + _, + _, + da_target_times_true, + ) = dataset._build_item_dataarrays(idx=idx) target_times = np.array(target_times_arr, dtype="datetime64[ns]") np.testing.assert_equal(target_times, da_target_times_true.values) @@ -158,13 +195,19 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name): @pytest.mark.parametrize("split", ["train", "val", "test"]) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_single_batch(datastore_name, split): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_single_batch(datastore_name, datastore_boundary_name, split): """Check that the `datastore.get_dataarray` method is implemented. And that it returns an xarray DataArray with the correct dimensions. """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) device_name = ( torch.device("cuda") if torch.cuda.is_available() else "cpu" @@ -210,7 +253,9 @@ def _create_graph(): ) ) - dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2) + dataset = WeatherDataset( + datastore=datastore, datastore_boundary=datastore_boundary, split=split + ) model = GraphLAM(args=args, datastore=datastore, config=config) # noqa @@ -244,6 +289,7 @@ def test_dataset_length(dataset_config): dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=dataset_config["ar_steps"], num_past_forcing_steps=dataset_config["past"], diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 4a4b1100..a91f6245 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -18,8 +18,6 @@ dataarray for the given category. - `get_dataarray` (method): Return the processed data (as a single `xr.DataArray`) for the given category and test/train/val-split. -- `boundary_mask` (property): Return the boundary mask for the dataset, - with spatial dimensions stacked. - `config` (property): Return the configuration of the datastore. In addition BaseRegularGridDatastore must have the following methods and @@ -213,25 +211,6 @@ def test_get_dataarray(datastore_name): assert n_features["train"] == n_features["val"] == n_features["test"] -@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_boundary_mask(datastore_name): - """Check that the `datastore.boundary_mask` property is implemented and - that the returned object is an xarray DataArray with the correct shape.""" - datastore = init_datastore_example(datastore_name) - da_mask = datastore.boundary_mask - - assert isinstance(da_mask, xr.DataArray) - assert set(da_mask.dims) == {"grid_index"} - assert da_mask.dtype == "int" - assert set(da_mask.values) == {0, 1} - assert da_mask.sum() > 0 - assert da_mask.sum() < da_mask.size - - if isinstance(datastore, BaseRegularGridDatastore): - grid_shape = datastore.grid_shape_state - assert datastore.boundary_mask.size == grid_shape.x * grid_shape.y - - @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_xy_extent(datastore_name): """Check that the `datastore.get_xy_extent` method is implemented and that diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 29161505..4a59c81e 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -78,7 +78,7 @@ def get_vars_long_names(self, category): def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): - # state and forcing variables have only on dimension, `time` + # state and forcing variables have only one dimension, `time` time_values = np.datetime64("2020-01-01") + np.arange( len(ANALYSIS_STATE_VALUES) ) @@ -93,6 +93,7 @@ def test_time_slicing_analysis( dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, ar_steps=ar_steps, num_future_forcing_steps=num_future_forcing_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -101,7 +102,7 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _ = [ + init_states, target_states, forcing, _, _ = [ tensor.numpy() for tensor in sample ] @@ -130,7 +131,7 @@ def test_time_slicing_analysis( # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) - # forcing: (ar_steps, N_grid, d_windowed_forcing) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) # target_times: (ar_steps,) assert init_states.shape == (2, 1, 1) assert init_states[:, 0, 0].tolist() == expected_init_states @@ -141,6 +142,10 @@ def test_time_slicing_analysis( assert forcing.shape == ( 3, 1, - 1 + num_past_forcing_steps + num_future_forcing_steps, + # Factor 2 because each window step has a temporal embedding + (1 + num_past_forcing_steps + num_future_forcing_steps) * 2, + ) + np.testing.assert_equal( + forcing[:, 0, : num_past_forcing_steps + num_future_forcing_steps + 1], + np.array(expected_forcing_values), ) - np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values)) diff --git a/tests/test_training.py b/tests/test_training.py index 1ed1847d..ca0ebf41 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -14,18 +14,33 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataModule -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_training(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( f"Skipping test for {datastore_name} as it is not a regular " "grid datastore." ) + if not isinstance(datastore_boundary, BaseRegularGridDatastore): + pytest.skip( + f"Skipping test for {datastore_boundary_name} as it is not a " + "regular grid datastore." + ) if torch.cuda.is_available(): device_name = "cuda" @@ -59,6 +74,7 @@ def test_training(datastore_name): data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=3, ar_steps_eval=5, standardize=True, @@ -66,6 +82,8 @@ def test_training(datastore_name): num_workers=1, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, ) class ModelArgs: @@ -85,6 +103,8 @@ class ModelArgs: metrics_watch = [] num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 model_args = ModelArgs()