Skip to content

Commit

Permalink
Remove init and destroy from the trainer object.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaseris committed Feb 9, 2024
1 parent b07be21 commit 00bd8ab
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions src/skelcast/experiments/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group, destroy_process_group

from skelcast.callbacks.console import ConsoleCallback
from skelcast.callbacks.checkpoint import CheckpointCallback
Expand Down Expand Up @@ -47,6 +46,7 @@ def __init__(self,
self.val_sampler = DistributedSampler(self.val_set)
self.train_loader = DataLoader(self.train_set, batch_size=self.train_batch_size, sampler=self.train_sampler)
self.val_loader = DataLoader(self.val_set, batch_size=self.val_batch_size, sampler=self.val_sampler)
self.device = int(os.environ.get('LOCAL_RANK'))
self.model = DistributedDataParallel(self.model, device_ids=[self.device])
self.lr = lr

Expand Down Expand Up @@ -74,7 +74,6 @@ def __init__(self,
self.log_gradient_info = log_gradient_info

def setup(self):
init_process_group(backend='nccl')
self.model.to(self.device)
self._total_train_batches = len(self.train_set) // self.train_batch_size
self._total_val_batches = len(self.val_set) // self.val_batch_size
Expand Down Expand Up @@ -136,7 +135,6 @@ def _compile_results(self):

def fit(self):
self._run_epochs(start_epoch=0)
destroy_process_group()
return self._compile_results()

def training_step(self, train_batch):
Expand Down

0 comments on commit 00bd8ab

Please sign in to comment.