introduced parallelize flag
sadamov committed May 31, 2024
1 parent c9f8b50 commit 743a52c
Showing 1 changed file with 125 additions and 90 deletions.
create_parameter_weights.py
@@ -16,26 +16,29 @@
 
 
 class PaddedWeatherDataset(torch.utils.data.Dataset):
-    def __init__(self, base_dataset, world_size, batch_size):
+    def __init__(
+        self, base_dataset, world_size, batch_size, duplication_factor=1
+    ):
         super().__init__()
         self.base_dataset = base_dataset
         self.world_size = world_size
         self.batch_size = batch_size
-        self.total_samples = len(base_dataset)
+        self.duplication_factor = duplication_factor
+        self.total_samples = len(base_dataset) * duplication_factor
         self.padded_samples = (
             (self.world_size * self.batch_size) - self.total_samples
         ) % self.world_size
-        self.original_indices = list(range(self.total_samples))
+        self.original_indices = (
+            list(range(len(base_dataset))) * duplication_factor
+        )
         self.padded_indices = list(
             range(self.total_samples, self.total_samples + self.padded_samples)
         )
 
     def __getitem__(self, idx):
         if idx >= self.total_samples:
-            # Return a padded item (zeros or a repeat of the last item)
+            # Repeat last item
             return self.base_dataset[self.original_indices[-1]]
-        return self.base_dataset[idx]
+        return self.base_dataset[idx % len(self.base_dataset)]
 
     def __len__(self):
         return self.total_samples + self.padded_samples
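
A minimal sketch of how the padding arithmetic above plays out; the toy RangeDataset and the concrete numbers are illustrative assumptions, not part of the commit:

    import torch

    class RangeDataset(torch.utils.data.Dataset):
        """Toy stand-in for WeatherDataset."""

        def __len__(self):
            return 10

        def __getitem__(self, idx):
            return torch.tensor(idx)

    # total_samples = 10 * 2 = 20 and
    # padded_samples = (4 * 3 - 20) % 4 = 0, so len(ds) = 20
    ds = PaddedWeatherDataset(
        RangeDataset(), world_size=4, batch_size=3, duplication_factor=2
    )
    assert len(ds) % 4 == 0  # length is always a multiple of world_size

Because padded_samples is congruent to -total_samples modulo world_size, the padded length is always divisible by world_size, so DistributedSampler can shard it evenly instead of silently duplicating samples; __getitem__ answers any padded index with a repeat of the last item.
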
@@ -85,8 +88,7 @@ def setup(rank, world_size):  # pylint: disable=redefined-outer-name
     )
 
 
-def main(rank, world_size):  # pylint: disable=redefined-outer-name
-    """Compute the mean and standard deviation of the input data."""
+def main():  # pylint: disable=redefined-outer-name
     parser = ArgumentParser(description="Training arguments")
     parser.add_argument(
         "--data_config",
@@ -112,22 +114,36 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
         default=4,
         help="Number of workers in data loader (default: 4)",
     )
+    parser.add_argument(
+        "--duplication_factor",
+        type=int,
+        default=10,
+        help="Factor to duplicate the dataset for benchmarking",
+    )
+    parser.add_argument(
+        "--parallelize",
+        action="store_true",
+        help="Run the script in parallel mode",
+    )
     args = parser.parse_args()
 
+    rank = get_rank()
+    world_size = get_world_size()
+
     config_loader = config.Config.from_file(args.data_config)
 
-    if torch.cuda.is_available():
-        device = torch.device(f"cuda:{rank}")
-        torch.cuda.set_device(device)
-    else:
-        device = torch.device("cpu")
+    if args.parallelize:
+        setup(rank, world_size)
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            torch.cuda.set_device(device)
+        else:
+            device = torch.device("cpu")
 
     if rank == 0:
         static_dir_path = os.path.join(
             "data", config_loader.dataset.name, "static"
         )
 
         # Create parameter weights based on height
         w_dict = {
             "2": 1.0,
             "0": 0.1,
@@ -148,24 +164,32 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
             w_list.astype("float32"),
         )
 
     # Load dataset without any subsampling
     ds = WeatherDataset(
         config_loader.dataset.name,
         split="train",
         subsample_step=1,
         pred_length=63,
         standardize=False,
     )
-    ds = PaddedWeatherDataset(ds, world_size, args.batch_size)
+    if args.parallelize:
+        ds = PaddedWeatherDataset(
+            ds,
+            world_size,
+            args.batch_size,
+            duplication_factor=args.duplication_factor,
+        )
 
-    train_sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank)
+        sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank)
+    else:
+        sampler = None
     loader = torch.utils.data.DataLoader(
         ds,
         args.batch_size,
         shuffle=False,
         num_workers=args.n_workers,
-        sampler=train_sampler,
+        sampler=sampler,
     )
 
     # Compute mean and std.-dev. of each parameter (+ flux forcing) across
     # full dataset
     if rank == 0:
@@ -175,61 +199,60 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
     flux_means = []
     flux_squares = []
 
-    for i in range(100):
-        # Data loading and initial computations remain on GPU
-        for init_batch, target_batch, forcing_batch in tqdm(loader):
+    for init_batch, target_batch, forcing_batch in tqdm(loader):
+        if args.parallelize:
             init_batch, target_batch, forcing_batch = (
                 init_batch.to(device),
                 target_batch.to(device),
                 forcing_batch.to(device),
             )
-            batch = torch.cat((init_batch, target_batch), dim=1)
-            # Move to CPU after computation
-            means.append(torch.mean(batch, dim=(1, 2)).cpu())
-            # Move to CPU after computation
-            squares.append(torch.mean(batch**2, dim=(1, 2)).cpu())
-            flux_batch = forcing_batch[:, :, :, 1]
-            # Move to CPU after computation
-            flux_means.append(torch.mean(flux_batch).cpu())
-            # Move to CPU after computation
-            flux_squares.append(torch.mean(flux_batch**2).cpu())
+        batch = torch.cat((init_batch, target_batch), dim=1)
+        means.append(torch.mean(batch, dim=(1, 2)).cpu())
+        squares.append(torch.mean(batch**2, dim=(1, 2)).cpu())
+        flux_batch = forcing_batch[:, :, :, 1]
+        flux_means.append(torch.mean(flux_batch).cpu())
+        flux_squares.append(torch.mean(flux_batch**2).cpu())
 
-    means_gathered = [None] * world_size
-    squares_gathered = [None] * world_size
-    # Aggregation remains unchanged but ensures inputs are on CPU
-    dist.all_gather_object(means_gathered, torch.cat(means, dim=0))
-    dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0))
-
+    if args.parallelize:
+        means_gathered = [None] * world_size
+        squares_gathered = [None] * world_size
+        dist.all_gather_object(means_gathered, torch.cat(means, dim=0))
+        dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0))
+        if rank == 0:
+            means_all = torch.cat(means_gathered, dim=0)
+            squares_all = torch.cat(squares_gathered, dim=0)
+            original_indices = ds.get_original_indices()
+            means = [means_all[i] for i in original_indices]
+            squares = [squares_all[i] for i in original_indices]
     if rank == 0:
-        # Final computations and saving are done on CPU
-        means_all = torch.cat(means_gathered, dim=0)
-        squares_all = torch.cat(squares_gathered, dim=0)
-        original_indices = ds.get_original_indices()
-        means_filtered = [means_all[i] for i in original_indices]
-        squares_filtered = [squares_all[i] for i in original_indices]
-        mean = torch.mean(torch.stack(means_filtered), dim=0)
-        second_moment = torch.mean(torch.stack(squares_filtered), dim=0)
+        if len(means) > 1:
+            means = torch.stack(means)
+            squares = torch.stack(squares)
+        else:
+            means = means[0]
+            squares = squares[0]
+        mean = torch.mean(means, dim=0)
+        second_moment = torch.mean(squares, dim=0)
         std = torch.sqrt(second_moment - mean**2)
         torch.save(
             mean.cpu(), os.path.join(static_dir_path, "parameter_mean.pt")
-        )  # Ensure tensor is on CPU
-        torch.save(
-            std.cpu(), os.path.join(static_dir_path, "parameter_std.pt")
-        )  # Ensure tensor is on CPU
-
-        # flux_means_filtered = [flux_means[i] for i in original_indices]
-        # flux_squares_filtered = [flux_squares[i] for i in original_indices]
-        flux_means_all = torch.stack(flux_means)
-        flux_squares_all = torch.stack(flux_squares)
+        )
+        torch.save(std.cpu(), os.path.join(static_dir_path, "parameter_std.pt"))
+        if len(flux_means) > 1:
+            flux_means_all = torch.stack(flux_means)
+            flux_squares_all = torch.stack(flux_squares)
+        else:
+            flux_means_all = flux_means[0]
+            flux_squares_all = flux_squares[0]
         flux_mean = torch.mean(flux_means_all)
         flux_second_moment = torch.mean(flux_squares_all)
         flux_std = torch.sqrt(flux_second_moment - flux_mean**2)
         torch.save(
             torch.stack((flux_mean, flux_std)).cpu(),
             os.path.join(static_dir_path, "flux_stats.pt"),
-        )  # Ensure tensor is on CPU
-    # Compute mean and std.-dev. of one-step differences across the dataset
-    dist.barrier()
+        )
+    if args.parallelize:
+        dist.barrier()
     if rank == 0:
         print("Computing mean and std.-dev. for one-step differences...")
     ds_standard = WeatherDataset(
@@ -239,11 +262,19 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
         pred_length=63,
         standardize=True,
     )
-    ds_standard = PaddedWeatherDataset(ds_standard, world_size, args.batch_size)
+    if args.parallelize:
+        ds_standard = PaddedWeatherDataset(
+            ds_standard,
+            world_size,
+            args.batch_size,
+            duplication_factor=args.duplication_factor,
+        )
 
-    sampler_standard = DistributedSampler(
-        ds_standard, num_replicas=world_size, rank=rank
-    )
+        sampler_standard = DistributedSampler(
+            ds_standard, num_replicas=world_size, rank=rank
+        )
+    else:
+        sampler_standard = None
     loader_standard = torch.utils.data.DataLoader(
         ds_standard,
         args.batch_size,
@@ -257,9 +288,10 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
     diff_squares = []
 
     for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0):
-        init_batch, target_batch = init_batch.to(device), target_batch.to(
-            device
-        )
+        if args.parallelize:
+            init_batch, target_batch = init_batch.to(device), target_batch.to(
+                device
+            )
         batch = torch.cat((init_batch, target_batch), dim=1)
         stepped_batch = torch.cat(
             [
@@ -270,41 +302,44 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
         )
         batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1]
 
-        # Compute means and squares on GPU, then move to CPU for storage
         diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu())
         diff_squares.append(torch.mean(batch_diffs**2, dim=(1, 2)).cpu())
 
-    dist.barrier()
+    if args.parallelize:
+        dist.barrier()
 
-    diff_means_gathered = [None] * world_size
-    diff_squares_gathered = [None] * world_size
-    dist.all_gather_object(diff_means_gathered, torch.cat(diff_means, dim=0))
-    dist.all_gather_object(
-        diff_squares_gathered, torch.cat(diff_squares, dim=0)
-    )
+        diff_means_gathered = [None] * world_size
+        diff_squares_gathered = [None] * world_size
+        dist.all_gather_object(
+            diff_means_gathered, torch.cat(diff_means, dim=0)
+        )
+        dist.all_gather_object(
+            diff_squares_gathered, torch.cat(diff_squares, dim=0)
+        )
 
+        if rank == 0:
+            diff_means_all = torch.cat(diff_means_gathered, dim=0)
+            diff_squares_all = torch.cat(diff_squares_gathered, dim=0)
+            original_indices = ds_standard.get_original_indices()
+            diff_means = [diff_means_all[i] for i in original_indices]
+            diff_squares = [diff_squares_all[i] for i in original_indices]
     if rank == 0:
-        # Concatenate and compute final statistics on CPU
-        diff_means_all = torch.cat(diff_means_gathered, dim=0)
-        diff_squares_all = torch.cat(diff_squares_gathered, dim=0)
-        original_indices = ds_standard.get_original_indices()
-        diff_means_filtered = [diff_means_all[i] for i in original_indices]
-        diff_squares_filtered = [diff_squares_all[i] for i in original_indices]
-        diff_mean = torch.mean(torch.stack(diff_means_filtered), dim=0)
-        diff_second_moment = torch.mean(
-            torch.stack(diff_squares_filtered), dim=0
-        )
+        if len(diff_means) > 1:
+            diff_means = torch.stack(diff_means)
+            diff_squares = torch.stack(diff_squares)
+        else:
+            diff_means = diff_means[0]
+            diff_squares = diff_squares[0]
+        diff_mean = torch.mean(diff_means, dim=0)
+        diff_second_moment = torch.mean(diff_squares, dim=0)
        diff_std = torch.sqrt(diff_second_moment - diff_mean**2)
 
-        # Save tensors to disk, ensuring they are on CPU
         torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt"))
         torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt"))
 
-    dist.destroy_process_group()
+    if args.parallelize:
+        dist.destroy_process_group()
 
 
 if __name__ == "__main__":
-    rank = get_rank()
-    world_size = get_world_size()
-    setup(rank, world_size)
-    main(rank, world_size)
+    main()
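
With the flag in place the same script covers both execution modes. A usage sketch, assuming the get_rank/get_world_size helpers defined earlier in the file read the RANK/WORLD_SIZE environment variables that torchrun exports:

    # Serial run: no process group, no padding, no duplication
    python create_parameter_weights.py --batch_size 32

    # Parallel benchmarking run on 4 workers, dataset duplicated 10 times
    torchrun --nproc_per_node 4 create_parameter_weights.py \
        --batch_size 32 --parallelize --duplication_factor 10

In parallel mode each rank averages only its shard, the per-sample moments are exchanged with dist.all_gather_object, and rank 0 restores the original sample order via get_original_indices() before reducing to mean and std = sqrt(E[x**2] - E[x]**2).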
