Skip to content

Commit

Permalink
Merge pull request #80 from kaseris/feature/distributed
Browse files Browse the repository at this point in the history
Feature/distributed
  • Loading branch information
kaseris committed Feb 9, 2024
2 parents 045a024 + d7408bf commit 5e11474
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/skelcast/experiments/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from skelcast.core.registry import Registry

# Package-level registry; runner classes add themselves to it by decorating
# their definition with ``@RUNNERS.register_module()``.
RUNNERS = Registry()

# This import is intentionally placed after RUNNERS is created: the
# ``distributed`` module imports RUNNERS back from this package, so RUNNERS
# must already exist when that module is executed. Importing it here also
# runs the decorator that registers DistributedRunner.
from .distributed import DistributedRunner
36 changes: 36 additions & 0 deletions src/skelcast/experiments/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group, destroy_process_group

from skelcast.models import SkelcastModule
from skelcast.experiments import RUNNERS

@RUNNERS.register_module()
class DistributedRunner:
    """Prepare datasets and a model for distributed data-parallel execution.

    On construction the runner moves the model to the selected device, builds
    a ``DistributedSampler`` and a ``DataLoader`` for each of the training and
    validation sets, and finally wraps the model in
    ``DistributedDataParallel``.

    NOTE(review): nothing in this class calls ``init_process_group``;
    presumably the caller initialises the process group before constructing
    the runner -- confirm.
    """

    def __init__(self,
                 train_set: Dataset,
                 val_set: Dataset,
                 train_batch_size: int,
                 val_batch_size: int,
                 block_size: int,
                 model: SkelcastModule,
                 optimizer: torch.optim.Optimizer = None,) -> None:
        """Store the configuration and build samplers, loaders and the DDP model.

        Args:
            train_set: Dataset used to build the training loader.
            val_set: Dataset used to build the validation loader.
            train_batch_size: Batch size passed to the training ``DataLoader``.
            val_batch_size: Batch size passed to the validation ``DataLoader``.
            block_size: Stored on the instance; not otherwise used here.
            model: Module to move to the device and wrap in
                ``DistributedDataParallel``.
            optimizer: Optional optimizer; stored on the instance as given.
        """
        self.train_set = train_set
        self.val_set = val_set
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.block_size = block_size
        self.model = model
        self.optimizer = optimizer

        # Prefer CUDA when it is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)

        # Each dataset gets its own distributed sampler, which is then handed
        # to the corresponding loader.
        self.train_sampler = DistributedSampler(self.train_set)
        self.val_sampler = DistributedSampler(self.val_set)
        self.train_loader = DataLoader(self.train_set, batch_size=self.train_batch_size, sampler=self.train_sampler)
        self.val_loader = DataLoader(self.val_set, batch_size=self.val_batch_size, sampler=self.val_sampler)

        # DistributedDataParallel only accepts ``device_ids`` for CUDA
        # modules; for a CPU module it must be None, otherwise construction
        # fails. Pass the device only when we actually ended up on CUDA.
        ddp_device_ids = [self.device] if self.device.type == 'cuda' else None
        self.model = DistributedDataParallel(self.model, device_ids=ddp_device_ids)

0 comments on commit 5e11474

Please sign in to comment.