Skip to content

Commit

Permalink
Fixed GPU/CPU execution for single-node runs
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed May 30, 2024
1 parent 644025f commit c9f8b50
Showing 1 changed file with 137 additions and 106 deletions.
243 changes: 137 additions & 106 deletions create_parameter_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,35 @@
from neural_lam.weather_dataset import WeatherDataset


class PaddedWeatherDataset(torch.utils.data.Dataset):
    """Wrap a dataset and pad its length up to a multiple of ``world_size``.

    A ``DistributedSampler`` over a dataset whose length is not divisible
    by ``world_size`` duplicates samples on its own, which would break the
    index correspondence needed to filter per-sample statistics gathered
    across ranks.  This wrapper instead appends explicit padding entries
    (repeats of the last real sample) so every rank receives the same
    number of samples, and exposes :meth:`get_original_indices` so callers
    can discard the padded results afterwards.
    """

    def __init__(self, base_dataset, world_size, batch_size):
        super().__init__()
        self.base_dataset = base_dataset
        self.world_size = world_size
        # Kept for interface compatibility; divisibility by world_size
        # alone is what stops DistributedSampler from padding.
        self.batch_size = batch_size
        self.total_samples = len(base_dataset)
        # Smallest non-negative pad making len(self) divisible by
        # world_size.  Equivalent to the previous
        # ((world_size * batch_size) - total_samples) % world_size,
        # since the batch_size term vanishes modulo world_size.
        self.padded_samples = (-self.total_samples) % self.world_size
        self.original_indices = list(range(self.total_samples))
        self.padded_indices = list(
            range(self.total_samples, self.total_samples + self.padded_samples)
        )

    def __getitem__(self, idx):
        # Indices past the real data repeat the last real sample
        # (assumes a non-empty base dataset).
        if idx >= self.total_samples:
            return self.base_dataset[self.original_indices[-1]]
        return self.base_dataset[idx]

    def __len__(self):
        # Real samples plus padding entries.
        return self.total_samples + self.padded_samples

    def get_original_indices(self):
        """Return the indices of the real (non-padded) samples."""
        return self.original_indices


def get_rank():
"""Get the rank of the current process in the distributed group."""
if "SLURM_PROCID" in os.environ:
Expand All @@ -31,7 +60,7 @@ def get_world_size():

def setup(rank, world_size): # pylint: disable=redefined-outer-name
"""Initialize the distributed group."""
try:
if "SLURM_JOB_NODELIST" in os.environ:
master_node = (
subprocess.check_output(
"scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1",
Expand All @@ -40,9 +69,8 @@ def setup(rank, world_size): # pylint: disable=redefined-outer-name
.strip()
.decode("utf-8")
)
except Exception as e:
print(f"Error getting master node IP: {e}")
raise
else:
master_node = "localhost"
master_port = "12355"
os.environ["MASTER_ADDR"] = master_node
os.environ["MASTER_PORT"] = master_port
Expand All @@ -57,32 +85,8 @@ def setup(rank, world_size): # pylint: disable=redefined-outer-name
)


def cleanup():
    """Destroy the distributed group.

    Releases the resources held by ``torch.distributed``; call once per
    process after all collective operations have completed.
    """
    dist.destroy_process_group()


def adjust_dataset_size(ds, world_size, batch_size):
    # pylint: disable=redefined-outer-name
    """Adjust the dataset size to be divisible by world_size * batch_size.

    Trims the trailing remainder of ``ds`` so that its length is an exact
    multiple of ``world_size * batch_size``; returns ``ds`` unchanged when
    it already is.
    """
    total_samples = len(ds)
    chunk = world_size * batch_size
    subset_samples = total_samples - (total_samples % chunk)

    # Already evenly divisible: nothing to trim.
    if subset_samples == total_samples:
        return ds

    print(
        f"Dataset size adjusted from {total_samples} to "
        f"{subset_samples} to be divisible by (world_size * batch_size)."
    )
    return torch.utils.data.Subset(ds, range(subset_samples))


def main(rank, world_size): # pylint: disable=redefined-outer-name
"""Compute the mean and standard deviation of the input data."""
setup(rank, world_size)
parser = ArgumentParser(description="Training arguments")
parser.add_argument(
"--data_config",
Expand Down Expand Up @@ -111,34 +115,38 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name
args = parser.parse_args()

config_loader = config.Config.from_file(args.data_config)
device = torch.device(
f"cuda:{rank % torch.cuda.device_count()}"
if torch.cuda.is_available()
else "cpu"
)
static_dir_path = os.path.join("data", config_loader.dataset.name, "static")

# Create parameter weights based on height
# based on fig A.1 in graph cast paper
w_dict = {
"2": 1.0,
"0": 0.1,
"65": 0.065,
"1000": 0.1,
"850": 0.05,
"500": 0.03,
}
w_list = np.array(
[
w_dict[par.split("_")[-2]]
for par in config_loader.dataset.var_longnames
]
)
print("Saving parameter weights...")
np.save(
os.path.join(static_dir_path, "parameter_weights.npy"),
w_list.astype("float32"),
)

if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")

if rank == 0:
static_dir_path = os.path.join(
"data", config_loader.dataset.name, "static"
)

# Create parameter weights based on height
w_dict = {
"2": 1.0,
"0": 0.1,
"65": 0.065,
"1000": 0.1,
"850": 0.05,
"500": 0.03,
}
w_list = np.array(
[
w_dict[par.split("_")[-2]]
for par in config_loader.dataset.var_longnames
]
)
print("Saving parameter weights...")
np.save(
os.path.join(static_dir_path, "parameter_weights.npy"),
w_list.astype("float32"),
)

# Load dataset without any subsampling
ds = WeatherDataset(
Expand All @@ -147,9 +155,8 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name
subsample_step=1,
pred_length=63,
standardize=False,
) # Without standardization

ds = adjust_dataset_size(ds, world_size, args.batch_size)
)
ds = PaddedWeatherDataset(ds, world_size, args.batch_size)

train_sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank)
loader = torch.utils.data.DataLoader(
Expand All @@ -159,62 +166,80 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name
num_workers=args.n_workers,
sampler=train_sampler,
)
# Compute mean and std.-dev. of each parameter (+ flux forcing)
# across full dataset
print("Computing mean and std.-dev. for parameters...")
# Compute mean and std.-dev. of each parameter (+ flux forcing) across
# full dataset
if rank == 0:
print("Computing mean and std.-dev. for parameters...")
means = []
squares = []
flux_means = []
flux_squares = []
for init_batch, target_batch, forcing_batch in tqdm(loader):
batch = torch.cat((init_batch, target_batch), dim=1).to(
device
) # (N_batch, N_t, N_grid, d_features)
means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,)
squares.append(
torch.mean(batch**2, dim=(1, 2))
) # (N_batch, d_features,)

# Flux at 1st windowed position is index 1 in forcing
flux_batch = forcing_batch[:, :, :, 1]
flux_means.append(torch.mean(flux_batch)) # (,)
flux_squares.append(torch.mean(flux_batch**2)) # (,)
dist.barrier()

for i in range(100):
# Data loading and initial computations remain on GPU
for init_batch, target_batch, forcing_batch in tqdm(loader):
init_batch, target_batch, forcing_batch = (
init_batch.to(device),
target_batch.to(device),
forcing_batch.to(device),
)
batch = torch.cat((init_batch, target_batch), dim=1)
# Move to CPU after computation
means.append(torch.mean(batch, dim=(1, 2)).cpu())
# Move to CPU after computation
squares.append(torch.mean(batch**2, dim=(1, 2)).cpu())
flux_batch = forcing_batch[:, :, :, 1]
# Move to CPU after computation
flux_means.append(torch.mean(flux_batch).cpu())
# Move to CPU after computation
flux_squares.append(torch.mean(flux_batch**2).cpu())

means_gathered = [None] * world_size
squares_gathered = [None] * world_size
# Aggregation remains unchanged but ensures inputs are on CPU
dist.all_gather_object(means_gathered, torch.cat(means, dim=0))
dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0))

if rank == 0:
# Final computations and saving are done on CPU
means_all = torch.cat(means_gathered, dim=0)
squares_all = torch.cat(squares_gathered, dim=0)
mean = torch.mean(means_all, dim=0)
second_moment = torch.mean(squares_all, dim=0)
original_indices = ds.get_original_indices()
means_filtered = [means_all[i] for i in original_indices]
squares_filtered = [squares_all[i] for i in original_indices]
mean = torch.mean(torch.stack(means_filtered), dim=0)
second_moment = torch.mean(torch.stack(squares_filtered), dim=0)
std = torch.sqrt(second_moment - mean**2)
torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt"))
torch.save(std, os.path.join(static_dir_path, "parameter_std.pt"))
torch.save(
mean.cpu(), os.path.join(static_dir_path, "parameter_mean.pt")
) # Ensure tensor is on CPU
torch.save(
std.cpu(), os.path.join(static_dir_path, "parameter_std.pt")
) # Ensure tensor is on CPU

# flux_means_filtered = [flux_means[i] for i in original_indices]
# flux_squares_filtered = [flux_squares[i] for i in original_indices]
flux_means_all = torch.stack(flux_means)
flux_squares_all = torch.stack(flux_squares)
flux_mean = torch.mean(flux_means_all)
flux_second_moment = torch.mean(flux_squares_all)
flux_std = torch.sqrt(flux_second_moment - flux_mean**2)
torch.save(
{"mean": flux_mean, "std": flux_std},
torch.stack((flux_mean, flux_std)).cpu(),
os.path.join(static_dir_path, "flux_stats.pt"),
)
) # Ensure tensor is on CPU
# Compute mean and std.-dev. of one-step differences across the dataset
print("Computing mean and std.-dev. for one-step differences...")
dist.barrier()
if rank == 0:
print("Computing mean and std.-dev. for one-step differences...")
ds_standard = WeatherDataset(
config_loader.dataset.name,
split="train",
subsample_step=1,
pred_length=63,
standardize=True,
) # Re-load with standardization

ds_standard = adjust_dataset_size(ds_standard, world_size, args.batch_size)
)
ds_standard = PaddedWeatherDataset(ds_standard, world_size, args.batch_size)

sampler_standard = DistributedSampler(
ds_standard, num_replicas=world_size, rank=rank
Expand All @@ -232,27 +257,22 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name
diff_squares = []

for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0):
batch = torch.cat((init_batch, target_batch), dim=1).to(device)
# Note: batch contains only 1h-steps
init_batch, target_batch = init_batch.to(device), target_batch.to(
device
)
batch = torch.cat((init_batch, target_batch), dim=1)
stepped_batch = torch.cat(
[
batch[:, ss_i : used_subsample_len : args.step_length]
for ss_i in range(args.step_length)
],
dim=0,
)
# (N_batch', N_t, N_grid, d_features),
# N_batch' = args.step_length*N_batch

batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1]
# (N_batch', N_t-1, N_grid, d_features)

diff_means.append(
torch.mean(batch_diffs, dim=(1, 2))
) # (N_batch', d_features,)
diff_squares.append(
torch.mean(batch_diffs**2, dim=(1, 2))
) # (N_batch', d_features,)
# Compute means and squares on GPU, then move to CPU for storage
diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu())
diff_squares.append(torch.mean(batch_diffs**2, dim=(1, 2)).cpu())

dist.barrier()

Expand All @@ -262,18 +282,29 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name
dist.all_gather_object(
diff_squares_gathered, torch.cat(diff_squares, dim=0)
)
diff_means_all = torch.cat(diff_means_gathered, dim=0)
diff_squares_all = torch.cat(diff_squares_gathered, dim=0)
diff_mean = torch.mean(diff_means_all, dim=0)
diff_second_moment = torch.mean(diff_squares_all, dim=0)
diff_std = torch.sqrt(diff_second_moment - diff_mean**2)
torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt"))
torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt"))

cleanup()
if rank == 0:
# Concatenate and compute final statistics on CPU
diff_means_all = torch.cat(diff_means_gathered, dim=0)
diff_squares_all = torch.cat(diff_squares_gathered, dim=0)
original_indices = ds_standard.get_original_indices()
diff_means_filtered = [diff_means_all[i] for i in original_indices]
diff_squares_filtered = [diff_squares_all[i] for i in original_indices]
diff_mean = torch.mean(torch.stack(diff_means_filtered), dim=0)
diff_second_moment = torch.mean(
torch.stack(diff_squares_filtered), dim=0
)
diff_std = torch.sqrt(diff_second_moment - diff_mean**2)

# Save tensors to disk, ensuring they are on CPU
torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt"))
torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt"))

dist.destroy_process_group()


if __name__ == "__main__":
    # Determine this process's identity; get_rank()/get_world_size()
    # read SLURM environment variables when present.
    rank = get_rank()
    world_size = get_world_size()
    # Initialize the torch.distributed process group before computing.
    setup(rank, world_size)
    main(rank, world_size)

0 comments on commit c9f8b50

Please sign in to comment.