Implement on_batch_transfer logic to normalize data

sadamov committed May 25, 2024
1 parent 5b71be3 commit c014222
Showing 5 changed files with 29 additions and 63 deletions.
3 changes: 2 additions & 1 deletion README.md
```diff
@@ -218,7 +218,8 @@ data
 │ ├── parameter_std.pt - Std.-dev. of state parameters (create_parameter_weights.py)
 │ ├── diff_mean.pt - Means of one-step differences (create_parameter_weights.py)
 │ ├── diff_std.pt - Std.-dev. of one-step differences (create_parameter_weights.py)
-│ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (create_parameter_weights.py)
+│ ├── flux_mean.pt - Mean of solar flux forcing (create_parameter_weights.py)
+│ ├── flux_std.pt - Std.-dev. of solar flux forcing (create_parameter_weights.py)
 │ └── parameter_weights.npy - Loss weights for different state parameters (create_parameter_weights.py)
 ├── dataset2
 ├── ...
```
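
With the combined flux_stats.pt split into two scalar files, downstream code loads them separately. A minimal sketch, assuming a dataset named dataset1 laid out as in the tree above:

```python
import torch

# Scalar tensors written by create_parameter_weights.py
flux_mean = torch.load("data/dataset1/static/flux_mean.pt")  # shape (,)
flux_std = torch.load("data/dataset1/static/flux_std.pt")    # shape (,)
```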
24 changes: 9 additions & 15 deletions create_parameter_weights.py
```diff
@@ -74,7 +74,6 @@ def main():
         split="train",
         subsample_step=1,
         pred_length=63,
-        standardize=False,
     )  # Without standardization
     loader = torch.utils.data.DataLoader(
         ds, args.batch_size, shuffle=False, num_workers=args.n_workers
@@ -107,30 +106,25 @@ def main():
     flux_mean = torch.mean(torch.stack(flux_means))  # (,)
     flux_second_moment = torch.mean(torch.stack(flux_squares))  # (,)
     flux_std = torch.sqrt(flux_second_moment - flux_mean**2)  # (,)
-    flux_stats = torch.stack((flux_mean, flux_std))
 
-    print("Saving mean, std.-dev, flux_stats...")
+    print("Saving mean, std.-dev, flux_mean, flux_std...")
     torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt"))
     torch.save(std, os.path.join(static_dir_path, "parameter_std.pt"))
-    torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt"))
+    torch.save(flux_mean, os.path.join(static_dir_path, "flux_mean.pt"))
+    torch.save(flux_std, os.path.join(static_dir_path, "flux_std.pt"))
 
     # Compute mean and std.-dev. of one-step differences across the dataset
     print("Computing mean and std.-dev. for one-step differences...")
-    ds_standard = WeatherDataset(
-        config_loader.dataset.name,
-        split="train",
-        subsample_step=1,
-        pred_length=63,
-        standardize=True,
-    )  # Re-load with standardization
-    loader_standard = torch.utils.data.DataLoader(
-        ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers
-    )
     used_subsample_len = (65 // args.step_length) * args.step_length
 
     diff_means = []
     diff_squares = []
-    for init_batch, target_batch, _ in tqdm(loader_standard):
+    for init_batch, target_batch, _ in tqdm(loader):
+        # normalize the batch
+        init_batch = (init_batch - mean) / std
+        target_batch = (target_batch - mean) / std
+
+        batch = torch.cat((init_batch, target_batch), dim=1)
         batch = torch.cat(
             (init_batch, target_batch), dim=1
         )  # (N_batch, N_t', N_grid, d_features)
```
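
A note on the statistics above: the script derives the standard deviation from first and second moments, std = sqrt(E[x^2] - E[x]^2), by averaging per-batch moments (exact when all batches are equally sized). A self-contained sketch of the same pattern, with random tensors standing in for the real flux batches:

```python
import torch

# Dummy flux batches standing in for the WeatherDataset loader output
batches = [torch.rand(4, 63, 256) for _ in range(10)]

flux_means = [torch.mean(b) for b in batches]       # per-batch E[x]
flux_squares = [torch.mean(b**2) for b in batches]  # per-batch E[x^2]

flux_mean = torch.mean(torch.stack(flux_means))
flux_second_moment = torch.mean(torch.stack(flux_squares))
flux_std = torch.sqrt(flux_second_moment - flux_mean**2)  # sqrt(E[x^2] - E[x]^2)
```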
13 changes: 13 additions & 0 deletions neural_lam/models/ar_model.py
```diff
@@ -197,6 +197,19 @@ def common_step(self, batch):
 
         return prediction, target_states, pred_std
 
+    def on_after_batch_transfer(self, batch, dataloader_idx):
+        """Normalize Batch data after transferring to the device."""
+        init_states, target_states, forcing_features = batch
+        init_states = (init_states - self.data_mean) / self.data_std
+        target_states = (target_states - self.data_mean) / self.data_std
+        forcing_features = (forcing_features - self.flux_mean) / self.flux_std
+        batch = (
+            init_states,
+            target_states,
+            forcing_features,
+        )
+        return batch
+
     def training_step(self, batch):
         """
         Train on single batch
```
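
For context: PyTorch Lightning calls `on_after_batch_transfer` after the batch has been moved to the accelerator and before it reaches the `*_step` hooks, so the normalization above runs on-device, once per batch. A minimal sketch of the pattern, assuming the statistics live on the module (e.g. as buffers); the class is a toy, not the repo's ARModel:

```python
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    """Toy module illustrating on-device normalization."""

    def __init__(self, data_mean, data_std):
        super().__init__()
        # Buffers follow the module across devices, so the statistics are
        # already on the right device when the hook fires.
        self.register_buffer("data_mean", data_mean)
        self.register_buffer("data_std", data_std)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # Called once per batch, after device transfer, before *_step hooks
        return (batch - self.data_mean) / self.data_std
```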
30 changes: 5 additions & 25 deletions neural_lam/utils.py
```diff
@@ -8,31 +8,6 @@
 from tueplots import bundles, figsizes
 
 
-def load_dataset_stats(dataset_name, device="cpu"):
-    """
-    Load arrays with stored dataset statistics from pre-processing
-    """
-    static_dir_path = os.path.join("data", dataset_name, "static")
-
-    def loads_file(fn):
-        return torch.load(
-            os.path.join(static_dir_path, fn), map_location=device
-        )
-
-    data_mean = loads_file("parameter_mean.pt")  # (d_features,)
-    data_std = loads_file("parameter_std.pt")  # (d_features,)
-
-    flux_stats = loads_file("flux_stats.pt")  # (2,)
-    flux_mean, flux_std = flux_stats
-
-    return {
-        "data_mean": data_mean,
-        "data_std": data_std,
-        "flux_mean": flux_mean,
-        "flux_std": flux_std,
-    }
-
-
 def load_static_data(dataset_name, device="cpu"):
     """
     Load static files related to dataset
@@ -64,6 +39,9 @@ def loads_file(fn):
     data_mean = loads_file("parameter_mean.pt")  # (d_features,)
     data_std = loads_file("parameter_std.pt")  # (d_features,)
 
+    flux_mean = loads_file("flux_mean.pt")  # (,)
+    flux_std = loads_file("flux_std.pt")  # (,)
+
     # Load loss weighting vectors
     param_weights = torch.tensor(
         np.load(os.path.join(static_dir_path, "parameter_weights.npy")),
@@ -78,6 +56,8 @@ def loads_file(fn):
         "step_diff_std": step_diff_std,
         "data_mean": data_mean,
         "data_std": data_std,
+        "flux_mean": flux_mean,
+        "flux_std": flux_std,
         "param_weights": param_weights,
     }
 
```
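
One plausible way for a model to keep these statistics on the right device is to register them as buffers after calling `load_static_data`. A sketch under that assumption; the holder class is hypothetical, not part of the repo:

```python
import torch.nn as nn

from neural_lam import utils


class StaticDataHolder(nn.Module):  # hypothetical, for illustration only
    def __init__(self, dataset_name):
        super().__init__()
        static = utils.load_static_data(dataset_name)
        for name in ("data_mean", "data_std", "flux_mean", "flux_std"):
            # Buffers end up in state_dict and follow .to(device) / .cuda()
            self.register_buffer(name, static[name])
```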
22 changes: 0 additions & 22 deletions neural_lam/weather_dataset.py
```diff
@@ -7,9 +7,6 @@
 import numpy as np
 import torch
 
-# First-party
-from neural_lam import utils
-
 
 class WeatherDataset(torch.utils.data.Dataset):
     """
@@ -29,7 +26,6 @@ def __init__(
         pred_length=19,
         split="train",
         subsample_step=3,
-        standardize=True,
         subset=False,
         control_only=False,
     ):
@@ -61,17 +57,6 @@ def __init__(
             self.sample_length <= self.original_sample_length
         ), "Requesting too long time series samples"
 
-        # Set up for standardization
-        self.standardize = standardize
-        if standardize:
-            ds_stats = utils.load_dataset_stats(dataset_name, "cpu")
-            self.data_mean, self.data_std, self.flux_mean, self.flux_std = (
-                ds_stats["data_mean"],
-                ds_stats["data_std"],
-                ds_stats["flux_mean"],
-                ds_stats["flux_std"],
-            )
-
         # If subsample index should be sampled (only during training)
         self.random_subsample = split == "train"
 
@@ -148,10 +133,6 @@ def __getitem__(self, idx):
         sample = sample[init_id : (init_id + self.sample_length)]
         # (sample_length, N_grid, d_features)
 
-        if self.standardize:
-            # Standardize sample
-            sample = (sample - self.data_mean) / self.data_std
-
         # Split up sample in init. states and target states
         init_states = sample[:2]  # (2, N_grid, d_features)
         target_states = sample[2:]  # (sample_length-2, N_grid, d_features)
@@ -185,9 +166,6 @@ def __getitem__(self, idx):
             -1
         )  # (N_t', dim_x, dim_y, 1)
 
-        if self.standardize:
-            flux = (flux - self.flux_mean) / self.flux_std
-
         # Flatten and subsample flux forcing
         flux = flux.flatten(1, 2)  # (N_t, N_grid, 1)
         flux = flux[subsample_index :: self.subsample_step]  # (N_t, N_grid, 1)
```
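
With the `standardize` flag gone, call sites simply drop the keyword and get raw tensors back; normalization now happens once per batch in `on_after_batch_transfer` instead of once per sample in `__getitem__`. A sketch of the updated usage (the dataset name is a placeholder):

```python
from neural_lam.weather_dataset import WeatherDataset

ds = WeatherDataset(
    "my_dataset",  # placeholder dataset name
    split="train",
    subsample_step=1,
    pred_length=63,
)  # samples come back unstandardized; the model hook normalizes them
```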
