diff --git a/create_mesh.py b/create_mesh.py index 8b547166..36b9b0b5 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -197,7 +197,7 @@ def main(): graph_dir_path = os.path.join("graphs", args.graph) os.makedirs(graph_dir_path, exist_ok=True) - xy = config_loader.get_xy("static") + xy = config_loader.get_xy("static") # (2, N_y, N_x) grid_xy = torch.tensor(xy) pos_max = torch.max(torch.abs(grid_xy)) diff --git a/neural_lam/config.py b/neural_lam/config.py index d411bd1e..20185563 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -113,7 +113,6 @@ def open_zarr(self, category): return None def stack_grid(self, dataset): - """Stack the grid dimensions of the dataset.""" if dataset is None: return None dims = dataset.to_array().dims @@ -124,7 +123,7 @@ def stack_grid(self, dataset): else: if "x" not in dims or "y" not in dims: self.rename_dataset_dims_and_vars(dataset=dataset) - dataset = dataset.squeeze().stack(grid=("x", "y")) + dataset = dataset.squeeze().stack(grid=("y", "x")) return dataset def convert_dataset_to_dataarray(self, dataset): @@ -201,7 +200,7 @@ def filter_dimensions(self, dataset, transpose_array=True): return dataset def reshape_grid_to_2d(self, dataset, grid_shape=None): - """Reshape the grid to 2D.""" + """Reshape the grid to 2D for stacked data without multi-index.""" if grid_shape is None: grid_shape = dict(self.grid_shape_state.values.items()) x_dim, y_dim = (grid_shape["x"], grid_shape["y"]) @@ -209,7 +208,7 @@ def reshape_grid_to_2d(self, dataset, grid_shape=None): x_coords = np.arange(x_dim) y_coords = np.arange(y_dim) multi_index = pd.MultiIndex.from_product( - [x_coords, y_coords], names=["x", "y"] + [y_coords, x_coords], names=["y", "x"] ) mindex_coords = xr.Coordinates.from_pandas_multiindex( @@ -227,8 +226,8 @@ def get_xy(self, category): dataset = self.open_zarr(category) x, y = dataset.x.values, dataset.y.values if x.ndim == 1: - x, y = np.meshgrid(y, x) - xy = np.stack((x, y), axis=0) + x, y = np.meshgrid(x, y) + xy = np.stack((x, y), axis=0) # (2, N_y, N_x) return xy diff --git a/plot_graph.py b/plot_graph.py index 2c3f6238..dc3682ff 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -45,7 +45,7 @@ def main(): args = parser.parse_args() config_loader = config.Config.from_file(args.data_config) - xy = config_loader.get_xy("state") # (2, N_x, N_y) + xy = config_loader.get_xy("state") # (2, N_y, N_x) xy = xy.reshape(2, -1).T # (N_grid, 2) pos_max = np.max(np.abs(xy)) grid_pos = xy / pos_max # Divide by maximum coordinate