Skip to content

Commit

Permalink
Fixed order of y and x dims to adhere to #52
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed Jun 3, 2024
1 parent 8fa3ca7 commit a748903
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion create_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def main():
graph_dir_path = os.path.join("graphs", args.graph)
os.makedirs(graph_dir_path, exist_ok=True)

xy = config_loader.get_xy("static")
xy = config_loader.get_xy("static") # (2, N_y, N_x)
grid_xy = torch.tensor(xy)
pos_max = torch.max(torch.abs(grid_xy))

Expand Down
11 changes: 5 additions & 6 deletions neural_lam/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def open_zarr(self, category):
return None

def stack_grid(self, dataset):
"""Stack the grid dimensions of the dataset."""
if dataset is None:
return None
dims = dataset.to_array().dims
Expand All @@ -124,7 +123,7 @@ def stack_grid(self, dataset):
else:
if "x" not in dims or "y" not in dims:
self.rename_dataset_dims_and_vars(dataset=dataset)
dataset = dataset.squeeze().stack(grid=("x", "y"))
dataset = dataset.squeeze().stack(grid=("y", "x"))
return dataset

def convert_dataset_to_dataarray(self, dataset):
Expand Down Expand Up @@ -201,15 +200,15 @@ def filter_dimensions(self, dataset, transpose_array=True):
return dataset

def reshape_grid_to_2d(self, dataset, grid_shape=None):
"""Reshape the grid to 2D."""
"""Reshape the grid to 2D for stacked data without multi-index."""
if grid_shape is None:
grid_shape = dict(self.grid_shape_state.values.items())
x_dim, y_dim = (grid_shape["x"], grid_shape["y"])

x_coords = np.arange(x_dim)
y_coords = np.arange(y_dim)
multi_index = pd.MultiIndex.from_product(
[x_coords, y_coords], names=["x", "y"]
[y_coords, x_coords], names=["y", "x"]
)

mindex_coords = xr.Coordinates.from_pandas_multiindex(
Expand All @@ -227,8 +226,8 @@ def get_xy(self, category):
dataset = self.open_zarr(category)
x, y = dataset.x.values, dataset.y.values
if x.ndim == 1:
x, y = np.meshgrid(y, x)
xy = np.stack((x, y), axis=0)
x, y = np.meshgrid(x, y)
xy = np.stack((x, y), axis=0) # (2, N_y, N_x)

return xy

Expand Down
2 changes: 1 addition & 1 deletion plot_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def main():

args = parser.parse_args()
config_loader = config.Config.from_file(args.data_config)
xy = config_loader.get_xy("state") # (2, N_x, N_y)
xy = config_loader.get_xy("state") # (2, N_y, N_x)
xy = xy.reshape(2, -1).T # (N_grid, 2)
pos_max = np.max(np.abs(xy))
grid_pos = xy / pos_max # Divide by maximum coordinate
Expand Down

0 comments on commit a748903

Please sign in to comment.