Skip to content

Commit

Permalink
Fix formatting with pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Mar 17, 2024
1 parent 65f00ea commit 2d9afd1
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 14 deletions.
2 changes: 1 addition & 1 deletion neural_lam/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,5 @@
)

# Data dimensions
GRID_FORCING_DIM = 5 * 3 + 1 # 5 feat. for 3 time-step window + 1 batch-static
GRID_FORCING_DIM = 5 * 3 + 1 # 5 feat. for 3 time-step window + 1 batch-static
GRID_STATE_DIM = 17
8 changes: 2 additions & 6 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ def expand_to_batch(x, batch_size):
"""
return x.unsqueeze(0).expand(batch_size, -1, -1)

def predict_step(
self, prev_state, prev_prev_state, forcing
):
def predict_step(self, prev_state, prev_prev_state, forcing):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
prev_state: (B, num_grid_nodes, feature_dim), X_t
Expand All @@ -127,9 +125,7 @@ def predict_step(
"""
raise NotImplementedError("No prediction step implemented")

def unroll_prediction(
self, init_states, forcing_features, true_states
):
def unroll_prediction(self, init_states, forcing_features, true_states):
"""
Roll out prediction taking multiple autoregressive steps with model
init_states: (B, 2, num_grid_nodes, d_f)
Expand Down
4 changes: 1 addition & 3 deletions neural_lam/models/base_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ def process_step(self, mesh_rep):
"""
raise NotImplementedError("process_step not implemented")

def predict_step(
self, prev_state, prev_prev_state, forcing
):
def predict_step(self, prev_state, prev_prev_state, forcing):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
prev_state: (B, num_grid_nodes, feature_dim), X_t
Expand Down
6 changes: 2 additions & 4 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,9 @@ def __getitem__(self, idx):
) # (dim_x, dim_y, 1)
# Flatten
water_cover_features = water_cover_features.flatten(0, 1) # (N_grid, 1)
# Exand over temporal dimension
# Expand over temporal dimension
water_cover_expanded = water_cover_features.unsqueeze(0).expand(
self.sample_length - 2, # -2 as added on after windowing
-1,
-1
self.sample_length - 2, -1, -1 # -2 as added on after windowing
) # (sample_len, N_grid, 1)

# TOA flux
Expand Down

0 comments on commit 2d9afd1

Please sign in to comment.