From 5f34fa7d9c0a38f24ef213dec29a5f8cad3fbce4 Mon Sep 17 00:00:00 2001
From: Michail Kaseris
Date: Mon, 5 Feb 2024 20:24:21 +0200
Subject: [PATCH 1/3] Distributed data parallel runner

---
 src/skelcast/experiments/__init__.py    | 5 +++++
 src/skelcast/experiments/distributed.py | 5 +++++
 2 files changed, 10 insertions(+)
 create mode 100644 src/skelcast/experiments/distributed.py

diff --git a/src/skelcast/experiments/__init__.py b/src/skelcast/experiments/__init__.py
index e69de29..7733a21 100644
--- a/src/skelcast/experiments/__init__.py
+++ b/src/skelcast/experiments/__init__.py
@@ -0,0 +1,5 @@
+from skelcast.core.registry import Registry
+
+RUNNERS = Registry()
+
+from .distributed import DistributedRunner
\ No newline at end of file
diff --git a/src/skelcast/experiments/distributed.py b/src/skelcast/experiments/distributed.py
new file mode 100644
index 0000000..3e5c343
--- /dev/null
+++ b/src/skelcast/experiments/distributed.py
@@ -0,0 +1,5 @@
+from skelcast.experiments import RUNNERS
+
+@RUNNERS.register_module()
+class DistributedRunner:
+    pass
\ No newline at end of file

From 53b6b07c6f4b8bf18539897bc5eb071f30b63de9 Mon Sep 17 00:00:00 2001
From: Michail Kaseris
Date: Wed, 7 Feb 2024 22:29:49 +0200
Subject: [PATCH 2/3] Setup the imports

---
 src/skelcast/experiments/distributed.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/skelcast/experiments/distributed.py b/src/skelcast/experiments/distributed.py
index 3e5c343..0be6f70 100644
--- a/src/skelcast/experiments/distributed.py
+++ b/src/skelcast/experiments/distributed.py
@@ -1,3 +1,7 @@
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel
+from torch.distributed import init_process_group, destroy_process_group
+
 from skelcast.experiments import RUNNERS
 
 @RUNNERS.register_module()

From d7408bf95e34f31538de01656a890717d80204a0 Mon Sep 17 00:00:00 2001
From: Michail Kaseris
Date: Fri, 9 Feb 2024 11:59:28 +0200
Subject: [PATCH 3/3] Start setting up

---
 src/skelcast/experiments/distributed.py | 29 ++++++++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/src/skelcast/experiments/distributed.py b/src/skelcast/experiments/distributed.py
index 0be6f70..9fa2111 100644
--- a/src/skelcast/experiments/distributed.py
+++ b/src/skelcast/experiments/distributed.py
@@ -1,9 +1,36 @@
+import torch
+from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel
 from torch.distributed import init_process_group, destroy_process_group
 
+from skelcast.models import SkelcastModule
 from skelcast.experiments import RUNNERS
 
 @RUNNERS.register_module()
 class DistributedRunner:
-    pass
\ No newline at end of file
+
+    def __init__(self,
+                 train_set: Dataset,
+                 val_set: Dataset,
+                 train_batch_size: int,
+                 val_batch_size: int,
+                 block_size: int,
+                 model: SkelcastModule,
+                 optimizer: torch.optim.Optimizer = None,) -> None:
+
+        self.train_set = train_set
+        self.val_set = val_set
+        self.train_batch_size = train_batch_size
+        self.val_batch_size = val_batch_size
+        self.block_size = block_size
+        self.model = model
+        self.optimizer = optimizer
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model = self.model.to(self.device)
+        self.train_sampler = DistributedSampler(self.train_set)
+        self.val_sampler = DistributedSampler(self.val_set)
+        self.train_loader = DataLoader(self.train_set, batch_size=self.train_batch_size, sampler=self.train_sampler)
+        self.val_loader = DataLoader(self.val_set, batch_size=self.val_batch_size, sampler=self.val_sampler)
+        self.model = DistributedDataParallel(self.model, device_ids=[self.device])
+
\ No newline at end of file
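
Note (not part of the patch series): the runner above builds DistributedSampler and DistributedDataParallel objects, but nothing in these patches creates the process group yet. Below is a minimal, hypothetical sketch of how such a runner is typically driven under torchrun with one process per GPU; the script flow, the placeholder arguments, and the commented-out runner construction are illustrative assumptions, not code from the repository.

import os

import torch
from torch.distributed import init_process_group, destroy_process_group


def main():
    # torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every spawned process,
    # so the default env:// rendezvous is sufficient here.
    init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    # Each process would then build its own runner. The usual DistributedDataParallel
    # pattern is device_ids=[local_rank], i.e. one integer GPU index per process.
    # runner = DistributedRunner(train_set=..., val_set=..., train_batch_size=32,
    #                            val_batch_size=32, block_size=8, model=...)

    # Per epoch, DistributedSampler needs the epoch number to reshuffle its shards:
    # runner.train_sampler.set_epoch(epoch)

    destroy_process_group()


if __name__ == '__main__':
    main()

A launch would then look like: torchrun --nproc_per_node=4 train_ddp.py (the script name is illustrative).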