diff --git a/create_parameter_weights.py b/create_parameter_weights.py
index ca3a8665..e16ea707 100644
--- a/create_parameter_weights.py
+++ b/create_parameter_weights.py
@@ -16,26 +16,29 @@ class PaddedWeatherDataset(torch.utils.data.Dataset):
-    def __init__(self, base_dataset, world_size, batch_size):
+    def __init__(
+        self, base_dataset, world_size, batch_size, duplication_factor=1
+    ):
         super().__init__()
         self.base_dataset = base_dataset
         self.world_size = world_size
         self.batch_size = batch_size
-        self.total_samples = len(base_dataset)
+        self.duplication_factor = duplication_factor
+        self.total_samples = len(base_dataset) * duplication_factor
         self.padded_samples = (
             (self.world_size * self.batch_size) - self.total_samples
         ) % self.world_size
-        self.original_indices = list(range(self.total_samples))
+        self.original_indices = (
+            list(range(len(base_dataset))) * duplication_factor
+        )
         self.padded_indices = list(
             range(self.total_samples, self.total_samples + self.padded_samples)
         )
 
     def __getitem__(self, idx):
         if idx >= self.total_samples:
-            # Return a padded item (zeros or a repeat of the last item)
-            # Repeat last item
             return self.base_dataset[self.original_indices[-1]]
-        return self.base_dataset[idx]
+        return self.base_dataset[idx % len(self.base_dataset)]
 
     def __len__(self):
         return self.total_samples + self.padded_samples
@@ -85,8 +88,7 @@ def setup(rank, world_size):  # pylint: disable=redefined-outer-name
     )
 
 
-def main(rank, world_size):  # pylint: disable=redefined-outer-name
-    """Compute the mean and standard deviation of the input data."""
+def main():  # pylint: disable=redefined-outer-name
     parser = ArgumentParser(description="Training arguments")
     parser.add_argument(
         "--data_config",
@@ -112,22 +114,36 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
         default=4,
         help="Number of workers in data loader (default: 4)",
     )
+    parser.add_argument(
+        "--duplication_factor",
+        type=int,
+        default=10,
+        help="Factor to duplicate the dataset for benchmarking",
+    )
+    parser.add_argument(
+        "--parallelize",
+        action="store_true",
+        help="Run the script in parallel mode",
+    )
     args = parser.parse_args()
 
+    rank = get_rank()
+    world_size = get_world_size()
+
     config_loader = config.Config.from_file(args.data_config)
 
-    if torch.cuda.is_available():
-        device = torch.device(f"cuda:{rank}")
-        torch.cuda.set_device(device)
-    else:
-        device = torch.device("cpu")
+    if args.parallelize:
+        setup(rank, world_size)
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            torch.cuda.set_device(device)
+        else:
+            device = torch.device("cpu")
 
     if rank == 0:
         static_dir_path = os.path.join(
             "data", config_loader.dataset.name, "static"
         )
-
-        # Create parameter weights based on height
         w_dict = {
             "2": 1.0,
             "0": 0.1,
@@ -148,7 +164,6 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
             w_list.astype("float32"),
         )
 
-    # Load dataset without any subsampling
     ds = WeatherDataset(
         config_loader.dataset.name,
         split="train",
@@ -156,16 +171,25 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
         pred_length=63,
         standardize=False,
     )
-    ds = PaddedWeatherDataset(ds, world_size, args.batch_size)
+    if args.parallelize:
+        ds = PaddedWeatherDataset(
+            ds,
+            world_size,
+            args.batch_size,
+            duplication_factor=args.duplication_factor,
+        )
-    train_sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank)
+        sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank)
+    else:
+        sampler = None
 
     loader = torch.utils.data.DataLoader(
         ds,
         args.batch_size,
         shuffle=False,
         num_workers=args.n_workers,
-        sampler=train_sampler,
+        sampler=sampler,
     )
+
     # Compute mean and std.-dev. of each parameter (+ flux forcing) across
     # full dataset
     if rank == 0:
@@ -175,61 +199,60 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
     flux_means = []
     flux_squares = []
-    for i in range(100):
-        # Data loading and initial computations remain on GPU
-        for init_batch, target_batch, forcing_batch in tqdm(loader):
+    for init_batch, target_batch, forcing_batch in tqdm(loader):
+        if args.parallelize:
             init_batch, target_batch, forcing_batch = (
                 init_batch.to(device),
                 target_batch.to(device),
                 forcing_batch.to(device),
             )
-            batch = torch.cat((init_batch, target_batch), dim=1)
-            # Move to CPU after computation
-            means.append(torch.mean(batch, dim=(1, 2)).cpu())
-            # Move to CPU after computation
-            squares.append(torch.mean(batch**2, dim=(1, 2)).cpu())
-            flux_batch = forcing_batch[:, :, :, 1]
-            # Move to CPU after computation
-            flux_means.append(torch.mean(flux_batch).cpu())
-            # Move to CPU after computation
-            flux_squares.append(torch.mean(flux_batch**2).cpu())
-
-    means_gathered = [None] * world_size
-    squares_gathered = [None] * world_size
-    # Aggregation remains unchanged but ensures inputs are on CPU
-    dist.all_gather_object(means_gathered, torch.cat(means, dim=0))
-    dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0))
-
+        batch = torch.cat((init_batch, target_batch), dim=1)
+        means.append(torch.mean(batch, dim=(1, 2)).cpu())
+        squares.append(torch.mean(batch**2, dim=(1, 2)).cpu())
+        flux_batch = forcing_batch[:, :, :, 1]
+        flux_means.append(torch.mean(flux_batch).cpu())
+        flux_squares.append(torch.mean(flux_batch**2).cpu())
+
+    if args.parallelize:
+        means_gathered = [None] * world_size
+        squares_gathered = [None] * world_size
+        dist.all_gather_object(means_gathered, torch.cat(means, dim=0))
+        dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0))
+        if rank == 0:
+            means_all = torch.cat(means_gathered, dim=0)
+            squares_all = torch.cat(squares_gathered, dim=0)
+            original_indices = ds.get_original_indices()
+            means = [means_all[i] for i in original_indices]
+            squares = [squares_all[i] for i in original_indices]
     if rank == 0:
-        # Final computations and saving are done on CPU
-        means_all = torch.cat(means_gathered, dim=0)
-        squares_all = torch.cat(squares_gathered, dim=0)
-        original_indices = ds.get_original_indices()
-        means_filtered = [means_all[i] for i in original_indices]
-        squares_filtered = [squares_all[i] for i in original_indices]
-        mean = torch.mean(torch.stack(means_filtered), dim=0)
-        second_moment = torch.mean(torch.stack(squares_filtered), dim=0)
+        if len(means) > 1:
+            means = torch.stack(means)
+            squares = torch.stack(squares)
+        else:
+            means = means[0]
+            squares = squares[0]
+        mean = torch.mean(means, dim=0)
+        second_moment = torch.mean(squares, dim=0)
        std = torch.sqrt(second_moment - mean**2)
         torch.save(
             mean.cpu(), os.path.join(static_dir_path, "parameter_mean.pt")
-        )  # Ensure tensor is on CPU
-        torch.save(
-            std.cpu(), os.path.join(static_dir_path, "parameter_std.pt")
-        )  # Ensure tensor is on CPU
-
-        # flux_means_filtered = [flux_means[i] for i in original_indices]
-        # flux_squares_filtered = [flux_squares[i] for i in original_indices]
-        flux_means_all = torch.stack(flux_means)
-        flux_squares_all = torch.stack(flux_squares)
+        )
+        torch.save(std.cpu(), os.path.join(static_dir_path, "parameter_std.pt"))
+        if len(flux_means) > 1:
+            flux_means_all = torch.stack(flux_means)
+            flux_squares_all = torch.stack(flux_squares)
+        else:
+            flux_means_all = flux_means[0]
+            flux_squares_all = flux_squares[0]
         flux_mean = torch.mean(flux_means_all)
         flux_second_moment = torch.mean(flux_squares_all)
         flux_std = torch.sqrt(flux_second_moment - flux_mean**2)
         torch.save(
             torch.stack((flux_mean, flux_std)).cpu(),
             os.path.join(static_dir_path, "flux_stats.pt"),
-        )  # Ensure tensor is on CPU
-    # Compute mean and std.-dev. of one-step differences across the dataset
-    dist.barrier()
+        )
+    if args.parallelize:
+        dist.barrier()
     if rank == 0:
         print("Computing mean and std.-dev. for one-step differences...")
 
     ds_standard = WeatherDataset(
@@ -239,11 +262,19 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
         pred_length=63,
         standardize=True,
     )
-    ds_standard = PaddedWeatherDataset(ds_standard, world_size, args.batch_size)
+    if args.parallelize:
+        ds_standard = PaddedWeatherDataset(
+            ds_standard,
+            world_size,
+            args.batch_size,
+            duplication_factor=args.duplication_factor,
+        )
-    sampler_standard = DistributedSampler(
-        ds_standard, num_replicas=world_size, rank=rank
-    )
+        sampler_standard = DistributedSampler(
+            ds_standard, num_replicas=world_size, rank=rank
+        )
+    else:
+        sampler_standard = None
 
     loader_standard = torch.utils.data.DataLoader(
         ds_standard,
         args.batch_size,
@@ -257,9 +288,10 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
     diff_squares = []
 
     for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0):
-        init_batch, target_batch = init_batch.to(device), target_batch.to(
-            device
-        )
+        if args.parallelize:
+            init_batch, target_batch = init_batch.to(device), target_batch.to(
+                device
+            )
         batch = torch.cat((init_batch, target_batch), dim=1)
         stepped_batch = torch.cat(
             [
@@ -270,41 +302,44 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
         )
         batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1]
 
-        # Compute means and squares on GPU, then move to CPU for storage
         diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu())
         diff_squares.append(torch.mean(batch_diffs**2, dim=(1, 2)).cpu())
 
-    dist.barrier()
+    if args.parallelize:
+        dist.barrier()
 
-    diff_means_gathered = [None] * world_size
-    diff_squares_gathered = [None] * world_size
-    dist.all_gather_object(diff_means_gathered, torch.cat(diff_means, dim=0))
-    dist.all_gather_object(
-        diff_squares_gathered, torch.cat(diff_squares, dim=0)
-    )
+        diff_means_gathered = [None] * world_size
+        diff_squares_gathered = [None] * world_size
+        dist.all_gather_object(
+            diff_means_gathered, torch.cat(diff_means, dim=0)
+        )
+        dist.all_gather_object(
+            diff_squares_gathered, torch.cat(diff_squares, dim=0)
+        )
+        if rank == 0:
+            diff_means_all = torch.cat(diff_means_gathered, dim=0)
+            diff_squares_all = torch.cat(diff_squares_gathered, dim=0)
+            original_indices = ds_standard.get_original_indices()
+            diff_means = [diff_means_all[i] for i in original_indices]
+            diff_squares = [diff_squares_all[i] for i in original_indices]
 
     if rank == 0:
-        # Concatenate and compute final statistics on CPU
-        diff_means_all = torch.cat(diff_means_gathered, dim=0)
-        diff_squares_all = torch.cat(diff_squares_gathered, dim=0)
-        original_indices = ds_standard.get_original_indices()
-        diff_means_filtered = [diff_means_all[i] for i in original_indices]
-        diff_squares_filtered = [diff_squares_all[i] for i in original_indices]
-        diff_mean = torch.mean(torch.stack(diff_means_filtered), dim=0)
-        diff_second_moment = torch.mean(
-            torch.stack(diff_squares_filtered), dim=0
-        )
+        if len(diff_means) > 1:
+            diff_means = torch.stack(diff_means)
+            diff_squares = torch.stack(diff_squares)
+        else:
+            diff_means = diff_means[0]
+            diff_squares = diff_squares[0]
+        diff_mean = torch.mean(diff_means, dim=0)
+        diff_second_moment = torch.mean(diff_squares, dim=0)
         diff_std = torch.sqrt(diff_second_moment - diff_mean**2)
 
-        # Save tensors to disk, ensuring they are on CPU
        torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt"))
         torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt"))
 
-    dist.destroy_process_group()
+    if args.parallelize:
+        dist.destroy_process_group()
 
 
 if __name__ == "__main__":
-    rank = get_rank()
-    world_size = get_world_size()
-    setup(rank, world_size)
-    main(rank, world_size)
+    main()
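
For review purposes, the padding and duplication arithmetic in `PaddedWeatherDataset` is easiest to check in isolation. Below is a minimal sketch, not part of the patch, that mirrors the `__init__`/`__getitem__` logic from the first hunk with a toy stand-in dataset and verifies that the padded length splits evenly across ranks under `DistributedSampler`; `ToyDataset`, `PaddedDataset`, and the toy sizes are hypothetical stand-ins for `WeatherDataset` and the real configuration.

```python
import torch
from torch.utils.data import Dataset, DistributedSampler


class ToyDataset(Dataset):
    """Hypothetical stand-in for WeatherDataset: 7 scalar samples."""

    def __len__(self):
        return 7

    def __getitem__(self, idx):
        return torch.tensor(float(idx))


class PaddedDataset(Dataset):
    """Mirrors the padding/duplication arithmetic of PaddedWeatherDataset,
    trimmed to the parts needed for this demo."""

    def __init__(self, base, world_size, batch_size, duplication_factor=1):
        self.base = base
        self.total_samples = len(base) * duplication_factor
        # Pad up to the next multiple of world_size so all ranks stay in step.
        self.padded_samples = (
            (world_size * batch_size) - self.total_samples
        ) % world_size
        self.original_indices = list(range(len(base))) * duplication_factor

    def __len__(self):
        return self.total_samples + self.padded_samples

    def __getitem__(self, idx):
        if idx >= self.total_samples:
            # Padded slots repeat the last real sample.
            return self.base[self.original_indices[-1]]
        return self.base[idx % len(self.base)]


world_size, batch_size, duplication_factor = 4, 2, 3
ds = PaddedDataset(ToyDataset(), world_size, batch_size, duplication_factor)
assert len(ds) % world_size == 0  # 7 * 3 = 21 samples, padded to 24

# Explicit num_replicas/rank, so no process group is needed for the demo.
for rank in range(world_size):
    sampler = DistributedSampler(
        ds, num_replicas=world_size, rank=rank, shuffle=False
    )
    print(rank, [int(ds[i]) for i in sampler])  # 6 samples per rank
```

Because `num_replicas` and `rank` are passed explicitly, the sketch runs without initializing a process group. In the patch itself, the sampler, `setup()`, and the `dist.all_gather_object`/`dist.barrier`/`dist.destroy_process_group` calls are only reached when the new `--parallelize` flag is set, so a plain single-process run never exercises the distributed code paths.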