Start setting up
kaseris committed Feb 9, 2024
1 parent 53b6b07 commit d7408bf
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/skelcast/experiments/distributed.py
@@ -1,9 +1,36 @@
import os

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group, destroy_process_group

from skelcast.models import SkelcastModule
from skelcast.experiments import RUNNERS

@RUNNERS.register_module()
class DistributedRunner:

    def __init__(self,
                 train_set: Dataset,
                 val_set: Dataset,
                 train_batch_size: int,
                 val_batch_size: int,
                 block_size: int,
                 model: SkelcastModule,
                 optimizer: torch.optim.Optimizer = None) -> None:
        self.train_set = train_set
        self.val_set = val_set
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.block_size = block_size
        self.model = model
        self.optimizer = optimizer
        # Each process drives a single device; torchrun exports LOCAL_RANK.
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        self.device = torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        # DistributedSampler partitions each dataset across the processes in
        # the group; it requires init_process_group to have been called first.
        self.train_sampler = DistributedSampler(self.train_set)
        self.val_sampler = DistributedSampler(self.val_set)
        self.train_loader = DataLoader(self.train_set, batch_size=self.train_batch_size, sampler=self.train_sampler)
        self.val_loader = DataLoader(self.val_set, batch_size=self.val_batch_size, sampler=self.val_sampler)
        # device_ids expects device indices (ints), not an index-less
        # torch.device; leave it unset when running on CPU.
        self.model = DistributedDataParallel(
            self.model,
            device_ids=[local_rank] if torch.cuda.is_available() else None,
        )
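
Since init_process_group and destroy_process_group are imported but never called in this commit, here is a minimal, hypothetical launch sketch showing how the runner could be driven under torchrun. The dummy TensorDataset, the nn.Linear stand-in for a SkelcastModule, and the launch.py filename are all assumptions for illustration, not part of the repository.

# launch.py -- hypothetical usage sketch, not part of this commit.
# Run with: torchrun --nproc_per_node=2 launch.py
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.distributed import init_process_group, destroy_process_group

from skelcast.experiments.distributed import DistributedRunner

def main():
    # torchrun exports RANK, WORLD_SIZE and LOCAL_RANK, which
    # init_process_group and DistributedSampler read.
    init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo')
    try:
        # Dummy data and model, purely illustrative.
        xs, ys = torch.randn(128, 8), torch.randn(128, 8)
        dataset = TensorDataset(xs, ys)
        model = nn.Linear(8, 8)  # stands in for a real SkelcastModule
        runner = DistributedRunner(train_set=dataset,
                                   val_set=dataset,
                                   train_batch_size=16,
                                   val_batch_size=16,
                                   block_size=8,
                                   model=model)
        # A training loop is not implemented in this commit yet.
    finally:
        destroy_process_group()

if __name__ == '__main__':
    main()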
