Fix swapped x and y dimensions in comments and variable names
joeloskarsson committed Jun 3, 2024
1 parent e5400bb commit c2a0060
Showing 2 changed files with 17 additions and 17 deletions.
create_grid_features.py: 6 changes (3 additions & 3 deletions)
@@ -29,14 +29,14 @@ def main():
        # -- Static grid node features --
        grid_xy = torch.tensor(
            np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
-       )  # (2, N_x, N_y)
+       )  # (2, N_y, N_x)
        grid_xy = grid_xy.flatten(1, 2).T  # (N_grid, 2)
        pos_max = torch.max(torch.abs(grid_xy))
        grid_xy = grid_xy / pos_max  # Divide by maximum coordinate

        geopotential = torch.tensor(
            np.load(os.path.join(static_dir_path, "surface_geopotential.npy"))
        )  # (N_x, N_y)
+       )  # (N_y, N_x)
        geopotential = geopotential.flatten(0, 1).unsqueeze(1)  # (N_grid,1)
        gp_min = torch.min(geopotential)
        gp_max = torch.max(geopotential)
@@ -46,7 +46,7 @@ def main():
        grid_border_mask = torch.tensor(
            np.load(os.path.join(static_dir_path, "border_mask.npy")),
            dtype=torch.int64,
-       )  # (N_x, N_y)
+       )  # (N_y, N_x)
        grid_border_mask = (
            grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1)
        )  # (N_grid, 1)
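For context on the corrected ordering, here is a minimal sketch, assuming the (2, N_y, N_x) layout from the fixed comments. The synthetic meshgrid stands in for the real nwp_xy.npy, and the sizes come from the WeatherDataset docstring below:

import numpy as np
import torch

# Synthetic stand-in for np.load(".../nwp_xy.npy") (assumption: the real
# file stores x/y coordinates with y as the first spatial dimension)
N_y, N_x = 268, 238
xs, ys = np.meshgrid(np.arange(N_x), np.arange(N_y))  # each (N_y, N_x)
grid_xy = torch.tensor(np.stack([xs, ys]), dtype=torch.float32)  # (2, N_y, N_x)

# Same flattening as in create_grid_features.py above
grid_flat = grid_xy.flatten(1, 2).T  # (N_grid, 2), N_grid = N_y * N_x
assert grid_flat.shape == (N_y * N_x, 2)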
neural_lam/weather_dataset.py: 28 changes (14 additions & 14 deletions)
@@ -16,8 +16,8 @@ class WeatherDataset(torch.utils.data.Dataset):
    For our dataset:
    N_t' = 65
    N_t = 65//subsample_step (= 21 for 3h steps)
-   dim_x = 268
-   dim_y = 238
+   dim_y = 268
+   dim_x = 238
    N_grid = 268x238 = 63784
    d_features = 17 (d_features' = 18)
    d_forcing = 5
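A quick check of the numbers in this docstring (values taken from the diff itself):

dim_y, dim_x = 268, 238   # corrected: 268 rows (y) by 238 columns (x)
N_grid = dim_y * dim_x    # 268 * 238 = 63784, matching the docstring
N_t = 65 // 3             # = 21 for 3 h subsample steps
assert (N_grid, N_t) == (63784, 21)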
@@ -87,7 +87,7 @@ def __getitem__(self, idx):
        try:
            full_sample = torch.tensor(
                np.load(sample_path), dtype=torch.float32
-           )  # (N_t', dim_x, dim_y, d_features')
+           )  # (N_t', dim_y, dim_x, d_features')
        except ValueError:
            print(f"Failed to load {sample_path}")

@@ -101,40 +101,40 @@ def __getitem__(self, idx):
        sample = full_sample[
            subsample_index : subsample_end_index : self.subsample_step
        ]
-       # (N_t, dim_x, dim_y, d_features')
+       # (N_t, dim_y, dim_x, d_features')

        # Remove feature 15, "z_height_above_ground"
        sample = torch.cat(
            (sample[:, :, :, :15], sample[:, :, :, 16:]), dim=3
-       )  # (N_t, dim_x, dim_y, d_features)
+       )  # (N_t, dim_y, dim_x, d_features)

        # Accumulate solar radiation instead of just subsampling
-       rad_features = full_sample[:, :, :, 2:4]  # (N_t', dim_x, dim_y, 2)
+       rad_features = full_sample[:, :, :, 2:4]  # (N_t', dim_y, dim_x, 2)
        # Accumulate for first time step
        init_accum_rad = torch.sum(
            rad_features[: (subsample_index + 1)], dim=0, keepdim=True
-       )  # (1, dim_x, dim_y, 2)
+       )  # (1, dim_y, dim_x, 2)
        # Accumulate for rest of subsampled sequence
        in_subsample_len = (
            subsample_end_index - self.subsample_step + subsample_index + 1
        )
        rad_features_in_subsample = rad_features[
            (subsample_index + 1) : in_subsample_len
-       ]  # (N_t*, dim_x, dim_y, 2), N_t* = (N_t-1)*ss_step
-       _, dim_x, dim_y, _ = sample.shape
+       ]  # (N_t*, dim_y, dim_x, 2), N_t* = (N_t-1)*ss_step
+       _, dim_y, dim_x, _ = sample.shape
        rest_accum_rad = torch.sum(
            rad_features_in_subsample.view(
                self.original_sample_length - 1,
                self.subsample_step,
-               dim_x,
                dim_y,
+               dim_x,
                2,
            ),
            dim=1,
-       )  # (N_t-1, dim_x, dim_y, 2)
+       )  # (N_t-1, dim_y, dim_x, 2)
        accum_rad = torch.cat(
            (init_accum_rad, rest_accum_rad), dim=0
-       )  # (N_t, dim_x, dim_y, 2)
+       )  # (N_t, dim_y, dim_x, 2)
        # Replace in sample
        sample[:, :, :, 2:4] = accum_rad
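As a sketch of what this accumulation does, the toy example below reproduces the logic with small made-up sizes (all numbers are illustrative, not the real MEPS dimensions). Note that dim_y precedes dim_x when unpacking sample.shape, which is exactly what the renamed variables now reflect:

import torch

# Toy sizes (assumptions for illustration only)
N_t_full, dim_y, dim_x = 13, 4, 3    # full sample: 13 time steps on a 4x3 grid
ss_step, subsample_index = 3, 0
N_t = 1 + (N_t_full - 1) // ss_step  # 5 subsampled steps

rad = torch.rand(N_t_full, dim_y, dim_x, 2)  # (N_t', dim_y, dim_x, 2)

# First step: sum all flux up to and including the first subsampled index
init_accum = rad[: subsample_index + 1].sum(dim=0, keepdim=True)  # (1, dim_y, dim_x, 2)

# Remaining steps: group into windows of ss_step and sum within each window
rest = rad[subsample_index + 1 : subsample_index + 1 + (N_t - 1) * ss_step]
rest_accum = rest.view(N_t - 1, ss_step, dim_y, dim_x, 2).sum(dim=1)

accum = torch.cat((init_accum, rest_accum), dim=0)  # (N_t, dim_y, dim_x, 2)
assert accum.shape == (N_t, dim_y, dim_x, 2)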

@@ -168,7 +168,7 @@ def __getitem__(self, idx):
            np.load(water_path), dtype=torch.float32
        ).unsqueeze(
            -1
-       )  # (dim_x, dim_y, 1)
+       )  # (dim_y, dim_x, 1)
        # Flatten
        water_cover_features = water_cover_features.flatten(0, 1)  # (N_grid, 1)
        # Expand over temporal dimension
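The lines elided after this hunk repeat the flattened static field over time. A minimal sketch, assuming an unsqueeze-plus-expand along a new leading time dimension (the exact call is not shown in this diff):

import torch

dim_y, dim_x, N_t = 268, 238, 21      # sizes from the docstring above
water = torch.rand(dim_y, dim_x, 1)   # static field, y-major
water = water.flatten(0, 1)           # (N_grid, 1)
water = water.unsqueeze(0).expand(N_t, -1, -1)  # (N_t, N_grid, 1), a view, no copy
assert water.shape == (N_t, dim_y * dim_x, 1)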
@@ -183,7 +183,7 @@ def __getitem__(self, idx):
        )
        flux = torch.tensor(np.load(flux_path), dtype=torch.float32).unsqueeze(
            -1
-       )  # (N_t', dim_x, dim_y, 1)
+       )  # (N_t', dim_y, dim_x, 1)

        if self.standardize:
            flux = (flux - self.flux_mean) / self.flux_std
