linting and printing improvements
Simon Adamov committed May 2, 2024
1 parent 46bb470 commit 6d66b1e
Showing 1 changed file with 31 additions and 23 deletions.
create_parameter_weights.py (54 changes: 31 additions & 23 deletions)
@@ -19,27 +19,14 @@ def get_rank():
     """Get the rank of the current process in the distributed group."""
     if "SLURM_PROCID" in os.environ:
         return int(os.environ["SLURM_PROCID"])
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--rank", type=int, default=0, help="Rank of the current process"
-    )
-    args, _ = parser.parse_known_args()
-    return args.rank
+    return 0
 
 
 def get_world_size():
     """Get the number of processes in the distributed group."""
     if "SLURM_NTASKS" in os.environ:
         return int(os.environ["SLURM_NTASKS"])
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--world_size",
-        type=int,
-        default=1,
-        help="Number of processes in the distributed group",
-    )
-    args, _ = parser.parse_known_args()
-    return args.world_size
+    return 1
 
 
 def setup(rank, world_size):  # pylint: disable=redefined-outer-name
@@ -63,13 +50,36 @@ def setup(rank, world_size):  # pylint: disable=redefined-outer-name
         dist.init_process_group("nccl", rank=rank, world_size=world_size)
     else:
         dist.init_process_group("gloo", rank=rank, world_size=world_size)
+    print(
+        f"Initialized {dist.get_backend()} process group with "
+        f"world size "
+        f"{world_size}."
+    )
 
 
 def cleanup():
     """Destroy the distributed group."""
     dist.destroy_process_group()
 
 
+def adjust_dataset_size(ds, world_size, batch_size):
+    # pylint: disable=redefined-outer-name
+    """Adjust the dataset size to be divisible by world_size * batch_size."""
+    total_samples = len(ds)
+    subset_samples = (total_samples // (world_size * batch_size)) * (
+        world_size * batch_size
+    )
+
+    if subset_samples != total_samples:
+        ds = torch.utils.data.Subset(ds, range(subset_samples))
+        print(
+            f"Dataset size adjusted from {total_samples} to "
+            f"{subset_samples} to be divisible by (world_size * batch_size)."
+        )
+
+    return ds
+
+
 def main(rank, world_size):  # pylint: disable=redefined-outer-name
     """Compute the mean and standard deviation of the input data."""
     setup(rank, world_size)
@@ -100,11 +110,6 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
     )
     args = parser.parse_args()
 
-    if args.subset % (world_size * args.batch_size) != 0:
-        raise ValueError(
-            "Subset size must be divisible by (world_size * batch_size)"
-        )
-
     device = torch.device(
         f"cuda:{rank % torch.cuda.device_count()}"
         if torch.cuda.is_available()
@@ -140,6 +145,8 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
         standardize=False,
     )  # Without standardization
 
+    ds = adjust_dataset_size(ds, world_size, args.batch_size)
+
     train_sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank)
     loader = torch.utils.data.DataLoader(
         ds,
@@ -202,6 +209,9 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
         pred_length=63,
         standardize=True,
     )  # Re-load with standardization
+
+    ds_standard = adjust_dataset_size(ds_standard, world_size, args.batch_size)
+
     sampler_standard = DistributedSampler(
         ds_standard, num_replicas=world_size, rank=rank
     )
@@ -217,9 +227,7 @@ def main(rank, world_size):  # pylint: disable=redefined-outer-name
     diff_means = []
     diff_squares = []
 
-    for init_batch, target_batch, _, _ in tqdm(
-        loader_standard, disable=rank != 0
-    ):
+    for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0):
         batch = torch.cat((init_batch, target_batch), dim=1).to(device)
         # Note: batch contains only 1h-steps
         stepped_batch = torch.cat(
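To make the new truncation concrete, a worked example with illustrative numbers (not taken from the commit): with world_size = 4 and batch_size = 32, world_size * batch_size = 128, so a 10000-sample dataset is floored to the nearest multiple of 128; the removed check would instead have rejected any size not already divisible.

>>> world_size, batch_size, n = 4, 32, 10000  # illustrative values only
>>> (n // (world_size * batch_size)) * (world_size * batch_size)
9984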

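The truncation matters because DistributedSampler pads its index list until it divides evenly across replicas, and in a statistics pass those padded duplicates would bias the accumulated mean and standard deviation. A minimal standalone sketch of the effect, assuming a toy 10-sample dataset and illustrative world_size=4, batch_size=2 (none of this is from the commit):

import torch
from torch.utils.data import DistributedSampler, Subset, TensorDataset

ds = TensorDataset(torch.arange(10).float())

# 10 samples over 4 replicas: the sampler pads to 12 indices,
# so samples 0 and 1 are drawn twice.
for rank in range(4):
    sampler = DistributedSampler(ds, num_replicas=4, rank=rank, shuffle=False)
    print(rank, list(sampler))  # rank 2 -> [2, 6, 0], rank 3 -> [3, 7, 1]

# Truncated to a multiple of world_size * batch_size (10 -> 8),
# every sample is drawn exactly once per epoch.
subset = Subset(ds, range(8))
for rank in range(4):
    sampler = DistributedSampler(subset, num_replicas=4, rank=rank, shuffle=False)
    print(rank, list(sampler))  # rank 0 -> [0, 4], rank 1 -> [1, 5], ...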