diff --git a/create_parameter_weights.py b/create_parameter_weights.py index a4386c4c..ca3a8665 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -15,6 +15,35 @@ from neural_lam.weather_dataset import WeatherDataset +class PaddedWeatherDataset(torch.utils.data.Dataset): + def __init__(self, base_dataset, world_size, batch_size): + super().__init__() + self.base_dataset = base_dataset + self.world_size = world_size + self.batch_size = batch_size + self.total_samples = len(base_dataset) + self.padded_samples = ( + (self.world_size * self.batch_size) - self.total_samples + ) % self.world_size + self.original_indices = list(range(self.total_samples)) + self.padded_indices = list( + range(self.total_samples, self.total_samples + self.padded_samples) + ) + + def __getitem__(self, idx): + if idx >= self.total_samples: + # Return a padded item (zeros or a repeat of the last item) + # Repeat last item + return self.base_dataset[self.original_indices[-1]] + return self.base_dataset[idx] + + def __len__(self): + return self.total_samples + self.padded_samples + + def get_original_indices(self): + return self.original_indices + + def get_rank(): """Get the rank of the current process in the distributed group.""" if "SLURM_PROCID" in os.environ: @@ -31,7 +60,7 @@ def get_world_size(): def setup(rank, world_size): # pylint: disable=redefined-outer-name """Initialize the distributed group.""" - try: + if "SLURM_JOB_NODELIST" in os.environ: master_node = ( subprocess.check_output( "scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1", @@ -40,9 +69,8 @@ def setup(rank, world_size): # pylint: disable=redefined-outer-name .strip() .decode("utf-8") ) - except Exception as e: - print(f"Error getting master node IP: {e}") - raise + else: + master_node = "localhost" master_port = "12355" os.environ["MASTER_ADDR"] = master_node os.environ["MASTER_PORT"] = master_port @@ -57,32 +85,8 @@ def setup(rank, world_size): # pylint: disable=redefined-outer-name ) -def cleanup(): - """Destroy the distributed group.""" - dist.destroy_process_group() - - -def adjust_dataset_size(ds, world_size, batch_size): - # pylint: disable=redefined-outer-name - """Adjust the dataset size to be divisible by world_size * batch_size.""" - total_samples = len(ds) - subset_samples = (total_samples // (world_size * batch_size)) * ( - world_size * batch_size - ) - - if subset_samples != total_samples: - ds = torch.utils.data.Subset(ds, range(subset_samples)) - print( - f"Dataset size adjusted from {total_samples} to " - f"{subset_samples} to be divisible by (world_size * batch_size)." - ) - - return ds - - def main(rank, world_size): # pylint: disable=redefined-outer-name """Compute the mean and standard deviation of the input data.""" - setup(rank, world_size) parser = ArgumentParser(description="Training arguments") parser.add_argument( "--data_config", @@ -111,34 +115,38 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name args = parser.parse_args() config_loader = config.Config.from_file(args.data_config) - device = torch.device( - f"cuda:{rank % torch.cuda.device_count()}" - if torch.cuda.is_available() - else "cpu" - ) - static_dir_path = os.path.join("data", config_loader.dataset.name, "static") - - # Create parameter weights based on height - # based on fig A.1 in graph cast paper - w_dict = { - "2": 1.0, - "0": 0.1, - "65": 0.065, - "1000": 0.1, - "850": 0.05, - "500": 0.03, - } - w_list = np.array( - [ - w_dict[par.split("_")[-2]] - for par in config_loader.dataset.var_longnames - ] - ) - print("Saving parameter weights...") - np.save( - os.path.join(static_dir_path, "parameter_weights.npy"), - w_list.astype("float32"), - ) + + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + else: + device = torch.device("cpu") + + if rank == 0: + static_dir_path = os.path.join( + "data", config_loader.dataset.name, "static" + ) + + # Create parameter weights based on height + w_dict = { + "2": 1.0, + "0": 0.1, + "65": 0.065, + "1000": 0.1, + "850": 0.05, + "500": 0.03, + } + w_list = np.array( + [ + w_dict[par.split("_")[-2]] + for par in config_loader.dataset.var_longnames + ] + ) + print("Saving parameter weights...") + np.save( + os.path.join(static_dir_path, "parameter_weights.npy"), + w_list.astype("float32"), + ) # Load dataset without any subsampling ds = WeatherDataset( @@ -147,9 +155,8 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name subsample_step=1, pred_length=63, standardize=False, - ) # Without standardization - - ds = adjust_dataset_size(ds, world_size, args.batch_size) + ) + ds = PaddedWeatherDataset(ds, world_size, args.batch_size) train_sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank) loader = torch.utils.data.DataLoader( @@ -159,62 +166,80 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name num_workers=args.n_workers, sampler=train_sampler, ) - # Compute mean and std.-dev. of each parameter (+ flux forcing) - # across full dataset - print("Computing mean and std.-dev. for parameters...") + # Compute mean and std.-dev. of each parameter (+ flux forcing) across + # full dataset + if rank == 0: + print("Computing mean and std.-dev. for parameters...") means = [] squares = [] flux_means = [] flux_squares = [] - for init_batch, target_batch, forcing_batch in tqdm(loader): - batch = torch.cat((init_batch, target_batch), dim=1).to( - device - ) # (N_batch, N_t, N_grid, d_features) - means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,) - squares.append( - torch.mean(batch**2, dim=(1, 2)) - ) # (N_batch, d_features,) - - # Flux at 1st windowed position is index 1 in forcing - flux_batch = forcing_batch[:, :, :, 1] - flux_means.append(torch.mean(flux_batch)) # (,) - flux_squares.append(torch.mean(flux_batch**2)) # (,) - dist.barrier() + + for i in range(100): + # Data loading and initial computations remain on GPU + for init_batch, target_batch, forcing_batch in tqdm(loader): + init_batch, target_batch, forcing_batch = ( + init_batch.to(device), + target_batch.to(device), + forcing_batch.to(device), + ) + batch = torch.cat((init_batch, target_batch), dim=1) + # Move to CPU after computation + means.append(torch.mean(batch, dim=(1, 2)).cpu()) + # Move to CPU after computation + squares.append(torch.mean(batch**2, dim=(1, 2)).cpu()) + flux_batch = forcing_batch[:, :, :, 1] + # Move to CPU after computation + flux_means.append(torch.mean(flux_batch).cpu()) + # Move to CPU after computation + flux_squares.append(torch.mean(flux_batch**2).cpu()) means_gathered = [None] * world_size squares_gathered = [None] * world_size + # Aggregation remains unchanged but ensures inputs are on CPU dist.all_gather_object(means_gathered, torch.cat(means, dim=0)) dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0)) if rank == 0: + # Final computations and saving are done on CPU means_all = torch.cat(means_gathered, dim=0) squares_all = torch.cat(squares_gathered, dim=0) - mean = torch.mean(means_all, dim=0) - second_moment = torch.mean(squares_all, dim=0) + original_indices = ds.get_original_indices() + means_filtered = [means_all[i] for i in original_indices] + squares_filtered = [squares_all[i] for i in original_indices] + mean = torch.mean(torch.stack(means_filtered), dim=0) + second_moment = torch.mean(torch.stack(squares_filtered), dim=0) std = torch.sqrt(second_moment - mean**2) - torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt")) - torch.save(std, os.path.join(static_dir_path, "parameter_std.pt")) + torch.save( + mean.cpu(), os.path.join(static_dir_path, "parameter_mean.pt") + ) # Ensure tensor is on CPU + torch.save( + std.cpu(), os.path.join(static_dir_path, "parameter_std.pt") + ) # Ensure tensor is on CPU + # flux_means_filtered = [flux_means[i] for i in original_indices] + # flux_squares_filtered = [flux_squares[i] for i in original_indices] flux_means_all = torch.stack(flux_means) flux_squares_all = torch.stack(flux_squares) flux_mean = torch.mean(flux_means_all) flux_second_moment = torch.mean(flux_squares_all) flux_std = torch.sqrt(flux_second_moment - flux_mean**2) torch.save( - {"mean": flux_mean, "std": flux_std}, + torch.stack((flux_mean, flux_std)).cpu(), os.path.join(static_dir_path, "flux_stats.pt"), - ) + ) # Ensure tensor is on CPU # Compute mean and std.-dev. of one-step differences across the dataset - print("Computing mean and std.-dev. for one-step differences...") + dist.barrier() + if rank == 0: + print("Computing mean and std.-dev. for one-step differences...") ds_standard = WeatherDataset( config_loader.dataset.name, split="train", subsample_step=1, pred_length=63, standardize=True, - ) # Re-load with standardization - - ds_standard = adjust_dataset_size(ds_standard, world_size, args.batch_size) + ) + ds_standard = PaddedWeatherDataset(ds_standard, world_size, args.batch_size) sampler_standard = DistributedSampler( ds_standard, num_replicas=world_size, rank=rank @@ -232,8 +257,10 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name diff_squares = [] for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0): - batch = torch.cat((init_batch, target_batch), dim=1).to(device) - # Note: batch contains only 1h-steps + init_batch, target_batch = init_batch.to(device), target_batch.to( + device + ) + batch = torch.cat((init_batch, target_batch), dim=1) stepped_batch = torch.cat( [ batch[:, ss_i : used_subsample_len : args.step_length] @@ -241,18 +268,11 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name ], dim=0, ) - # (N_batch', N_t, N_grid, d_features), - # N_batch' = args.step_length*N_batch - batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1] - # (N_batch', N_t-1, N_grid, d_features) - diff_means.append( - torch.mean(batch_diffs, dim=(1, 2)) - ) # (N_batch', d_features,) - diff_squares.append( - torch.mean(batch_diffs**2, dim=(1, 2)) - ) # (N_batch', d_features,) + # Compute means and squares on GPU, then move to CPU for storage + diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu()) + diff_squares.append(torch.mean(batch_diffs**2, dim=(1, 2)).cpu()) dist.barrier() @@ -262,18 +282,29 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name dist.all_gather_object( diff_squares_gathered, torch.cat(diff_squares, dim=0) ) - diff_means_all = torch.cat(diff_means_gathered, dim=0) - diff_squares_all = torch.cat(diff_squares_gathered, dim=0) - diff_mean = torch.mean(diff_means_all, dim=0) - diff_second_moment = torch.mean(diff_squares_all, dim=0) - diff_std = torch.sqrt(diff_second_moment - diff_mean**2) - torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt")) - torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt")) - cleanup() + if rank == 0: + # Concatenate and compute final statistics on CPU + diff_means_all = torch.cat(diff_means_gathered, dim=0) + diff_squares_all = torch.cat(diff_squares_gathered, dim=0) + original_indices = ds_standard.get_original_indices() + diff_means_filtered = [diff_means_all[i] for i in original_indices] + diff_squares_filtered = [diff_squares_all[i] for i in original_indices] + diff_mean = torch.mean(torch.stack(diff_means_filtered), dim=0) + diff_second_moment = torch.mean( + torch.stack(diff_squares_filtered), dim=0 + ) + diff_std = torch.sqrt(diff_second_moment - diff_mean**2) + + # Save tensors to disk, ensuring they are on CPU + torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt")) + torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt")) + + dist.destroy_process_group() if __name__ == "__main__": rank = get_rank() world_size = get_world_size() + setup(rank, world_size) main(rank, world_size)