Feature/distributed #81

Merged: 4 commits, merged Feb 9, 2024
9 changes: 8 additions & 1 deletion src/skelcast/callbacks/console.py
@@ -2,6 +2,8 @@
import time
from datetime import datetime

import torch.distributed as dist

from skelcast.callbacks.callback import Callback


@@ -15,6 +17,10 @@ def __init__(self):
        self.validation_batches = 0
        self.training_batches = 0
        self.total_batches = 0
        if dist.is_initialized():
            self.rank = dist.get_rank()
        else:
            self.rank = None

    def on_epoch_start(self, epoch):
        self.current_epoch = epoch
@@ -51,7 +57,8 @@ def _print_status(self):
        now = datetime.now()
        now_formatted = now.strftime("[%Y-%m-%d %H:%M:%S]")
        clear_line = '\r' + ' ' * 80  # Create a line of 80 spaces
        message = f"{now_formatted} Epoch: {self.current_epoch + 1}/{self.final_epoch}, Batch: {self.current_batch}/{self.total_batches}, Train Loss: {self.latest_train_loss}, Val Loss: {self.latest_val_loss}"
        rank_info = f"Rank: {self.rank}, " if self.rank is not None else ""
        message = f"{now_formatted} {rank_info}Epoch: {self.current_epoch + 1}/{self.final_epoch}, Batch: {self.current_batch}/{self.total_batches}, Train Loss: {self.latest_train_loss}, Val Loss: {self.latest_val_loss}"

        # First, print the clear_line to overwrite the previous output, then print your message
        print(f'{clear_line}\r{message}', end='')
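The callback only queries the process rank when a process group is actually live, so the same console output works in both single-process and torchrun launches. Below is a minimal sketch of that guard in isolation; the script name and helper function are hypothetical and not part of this PR.

# rank_guard_demo.py -- hypothetical illustration of the guard used in ConsoleCallback.
# Under torchrun a process group exists and get_rank() is valid; in a plain
# `python` run, is_initialized() is False and the rank stays None.
import torch.distributed as dist

def current_rank():
    # Mirror ConsoleCallback.__init__: only ask for the rank when a group exists.
    return dist.get_rank() if dist.is_initialized() else None

if __name__ == "__main__":
    print(f"rank: {current_rank()}")  # prints "rank: None" outside a distributed launch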
164 changes: 162 additions & 2 deletions src/skelcast/experiments/distributed.py
@@ -1,9 +1,15 @@
import os

import torch
import torch.distributed as dist

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group, destroy_process_group

from skelcast.callbacks.console import ConsoleCallback
from skelcast.callbacks.checkpoint import CheckpointCallback
from skelcast.logger.base import BaseLogger
from skelcast.models import SkelcastModule
from skelcast.experiments import RUNNERS

@@ -17,7 +23,15 @@ def __init__(self,
                 val_batch_size: int,
                 block_size: int,
                 model: SkelcastModule,
                 optimizer: torch.optim.Optimizer = None,) -> None:
                 optimizer: torch.optim.Optimizer = None,
                 lr: float = 1e-4,
                 n_epochs: int = 10,
                 checkpoint_dir: str = None,
                 checkpoint_frequency: int = 1,
                 logger: BaseLogger = None,
                 log_gradient_info: bool = False,
                 collate_fn = None
                 ) -> None:

        self.train_set = train_set
        self.val_set = val_set
@@ -32,5 +46,151 @@ def __init__(self,
        self.val_sampler = DistributedSampler(self.val_set)
        self.train_loader = DataLoader(self.train_set, batch_size=self.train_batch_size, sampler=self.train_sampler)
        self.val_loader = DataLoader(self.val_set, batch_size=self.val_batch_size, sampler=self.val_sampler)
        self.device = int(os.environ.get('LOCAL_RANK'))
        self.model = DistributedDataParallel(self.model, device_ids=[self.device])
        self.lr = lr

        if optimizer is None:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        else:
            self.optimizer = optimizer

        self.training_loss_history = []
        self.training_loss_per_step = []
        self.validation_loss_history = []
        self.validation_loss_per_step = []

        self.n_epochs = n_epochs

        self._status_message = ''

        self.console_callback = ConsoleCallback()
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_frequency = checkpoint_frequency
        assert os.path.exists(self.checkpoint_dir), f'The designated checkpoint directory `{self.checkpoint_dir}` does not exist.'
        self.checkpoint_callback = CheckpointCallback(checkpoint_dir=self.checkpoint_dir,
                                                      frequency=self.checkpoint_frequency)
        self.logger = logger
        self.log_gradient_info = log_gradient_info

    def setup(self):
        self.model.to(self.device)
        self._total_train_batches = len(self.train_set) // self.train_batch_size
        self._total_val_batches = len(self.val_set) // self.val_batch_size
        self.console_callback.final_epoch = self.n_epochs
        self.console_callback.training_batches = self._total_train_batches
        self.console_callback.validation_batches = self._total_val_batches

    def _run_epochs(self, start_epoch):
        for epoch in range(start_epoch, self.n_epochs):
            self.console_callback.on_epoch_start(epoch=epoch)
            self._run_phase('train', epoch)
            self._log_epoch_loss('train', epoch)
            self._run_phase('val', epoch)
            self._log_epoch_loss('val', epoch)
            self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self)

    def _run_phase(self, phase, epoch):
        loader = self.train_loader if phase == 'train' else self.val_loader
        step_method = self.training_step if phase == 'train' else self.validation_step
        loss_per_step = self.training_loss_per_step if phase == 'train' else self.validation_loss_per_step

        for batch_idx, batch in enumerate(loader):
            step_method(batch)
            self.console_callback.on_batch_end(batch_idx=batch_idx,
                                               loss=loss_per_step[-1],
                                               phase=phase)

    def _log_epoch_loss(self, phase, epoch):
        loss_per_step = self.training_loss_per_step if phase == 'train' else self.validation_loss_per_step
        total_batches = self._total_train_batches if phase == 'train' else self._total_val_batches
        epoch_loss = sum(loss_per_step[epoch * total_batches:(epoch + 1) * total_batches]) / total_batches
        self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase=phase)
        history = self.training_loss_history if phase == 'train' else self.validation_loss_history
        history.append(epoch_loss)
        self.logger.add_scalar(tag=f'{phase}/epoch_loss', scalar_value=epoch_loss, global_step=epoch)

    def resume(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        self._restore_state(checkpoint)
        start_epoch = checkpoint.get('epoch', 0) + 1
        self._run_epochs(start_epoch)
        return self._compile_results()

    def _restore_state(self, checkpoint):
        self.model.load_state_dict(checkpoint.get('model_state_dict'))
        self.optimizer.load_state_dict(checkpoint.get('optimizer_state_dict'))
        self.training_loss_history = checkpoint.get('training_loss_history')
        self.validation_loss_history = checkpoint.get('validation_loss_history')
        self.training_loss_per_step = checkpoint.get('training_loss_per_step', [])
        self.validation_loss_per_step = checkpoint.get('validation_loss_per_step', [])

    def _compile_results(self):
        return {
            'training_loss_history': self.training_loss_history,
            'training_loss_per_step': self.training_loss_per_step,
            'validation_loss_history': self.validation_loss_history,
            'validation_loss_per_step': self.validation_loss_per_step
        }

    def fit(self):
        self._run_epochs(start_epoch=0)
        return self._compile_results()

    def training_step(self, train_batch):
        x, y, mask = train_batch.x, train_batch.y, train_batch.mask
        # Cast them to a torch float32 and move them to the gpu
        # TODO: Handle the mask None case
        x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
        x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
        self.model.train()
        out = self.model.training_step(x=x, y=y, mask=mask)  # TODO: Make the other models accept a mask as well
        loss = out['loss']
        outputs = out['out']
        # Calculate the saturation of the tanh output
        saturated = (outputs.abs() > 0.95)
        saturation_percentage = saturated.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100
        # Calculate the dead neurons
        dead_neurons = (outputs.abs() < 0.05)
        dead_neurons_percentage = dead_neurons.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100
        self.optimizer.zero_grad()
        loss.backward()
        if self.log_gradient_info:
            # Get the gradient flow and update norm ratio
            self.model.gradient_flow()
            self.model.compute_gradient_update_norm(lr=self.optimizer.param_groups[0]['lr'])
            grad_hists = self.model.get_gradient_histograms()
            # Log the gradient histograms to the logger
            if self.logger is not None:
                for name, hist in grad_hists.items():
                    self.logger.add_histogram(tag=f'gradient/hists/{name}_grad_hist', values=hist, global_step=len(self.training_loss_per_step))

            # Log the gradient updates to the logger
            if self.logger is not None:
                for name, ratio in self.model.gradient_update_ratios.items():
                    self.logger.add_scalar(tag=f'gradient/{name}_grad_update_norm_ratio', scalar_value=ratio, global_step=len(self.training_loss_per_step))

        if self.logger is not None:
            self.logger.add_scalar(tag='train/saturation', scalar_value=saturation_percentage.mean().item(), global_step=len(self.training_loss_per_step))
            self.logger.add_scalar(tag='train/dead_neurons', scalar_value=dead_neurons_percentage.mean().item(), global_step=len(self.training_loss_per_step))

        self.optimizer.step()
        # Print the loss
        self.training_loss_per_step.append(loss.item())
        # Log it to the logger
        if self.logger is not None:
            self.logger.add_scalar(tag='train/step_loss', scalar_value=loss.item(), global_step=len(self.training_loss_per_step))

    def validation_step(self, val_batch):
        x, y, mask = val_batch.x, val_batch.y, val_batch.mask
        # Cast them to a torch float32 and move them to the gpu
        x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
        x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
        self.model.eval()
        out = self.model.validation_step(x=x, y=y, mask=mask)
        loss = out['loss']
        self.validation_loss_per_step.append(loss.item())
        # Log it to the logger
        if self.logger is not None:
            self.logger.add_scalar(tag='val/step_loss', scalar_value=loss.item(), global_step=len(self.validation_loss_per_step))
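For reference, resume() and _restore_state() read a checkpoint dict with the keys shown below. The exact payload written by CheckpointCallback is not part of this diff, so the helper is only a sketch of the contract the runner consumes; note that self.model is DDP-wrapped, so its state-dict keys carry a `module.` prefix.

# Sketch (not in this PR) of a checkpoint layout compatible with resume()/_restore_state().
import torch

def save_resume_checkpoint(path, runner, epoch):
    torch.save({
        'epoch': epoch,                                          # resume() restarts at epoch + 1
        'model_state_dict': runner.model.state_dict(),           # DDP-wrapped: keys are prefixed with 'module.'
        'optimizer_state_dict': runner.optimizer.state_dict(),
        'training_loss_history': runner.training_loss_history,
        'validation_loss_history': runner.validation_loss_history,
        'training_loss_per_step': runner.training_loss_per_step,
        'validation_loss_per_step': runner.validation_loss_per_step,
    }, path)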

4 changes: 4 additions & 0 deletions tools/train_distributed.py
@@ -0,0 +1,4 @@
from torch.distributed import init_process_group, destroy_process_group

init_process_group(backend='nccl')
destroy_process_group()
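The script above is only the process-group skeleton. Below is a hedged sketch of how it might grow into a full entry point; the runner class name (DistributedRunner), the dataset, the model factory, and the logger are placeholders, since their real names are not visible in this diff. Batches produced by the loaders must expose .x, .y, and .mask, because training_step/validation_step read those attributes. Launch with one process per GPU, e.g. `torchrun --nproc_per_node=2 tools/train_distributed.py`, so that LOCAL_RANK is set for the runner.

# Hypothetical expansion of tools/train_distributed.py -- a sketch, not the PR's code.
import os

import torch
from torch.distributed import init_process_group, destroy_process_group

def main():
    init_process_group(backend='nccl')                       # one process per GPU, started by torchrun
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))     # pin this process to its local GPU

    runner = DistributedRunner(                               # assumed name of the runner in experiments/distributed.py
        train_set=MySkeletonDataset(split='train'),           # placeholder Dataset yielding x/y/mask batches
        val_set=MySkeletonDataset(split='val'),
        train_batch_size=32,
        val_batch_size=32,
        block_size=8,
        model=build_model(),                                  # placeholder SkelcastModule factory
        lr=1e-4,
        n_epochs=10,
        checkpoint_dir='/tmp/skelcast_ckpts',                 # must already exist (asserted in __init__)
        checkpoint_frequency=1,
        logger=build_logger(),                                # placeholder BaseLogger; the runner logs scalars every epoch
    )
    runner.setup()
    results = runner.fit()                                    # or runner.resume('/path/to/checkpoint.pt')
    destroy_process_group()
    return results

if __name__ == '__main__':
    main()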