Dev #30 (Merged)

merged 5 commits on Nov 29, 2023
Changes from all commits

192 changes: 108 additions & 84 deletions src/skelcast/experiments/runner.py
@@ -12,6 +12,48 @@
from skelcast.logger.base import BaseLogger

class Runner:
"""
A training and validation runner for models in the Skelcast framework.

This class handles the setup, training, validation, and checkpointing of SkelcastModule models.
It uses datasets for training and validation, and includes functionality for batch processing,
gradient logging, and checkpoint management.

Attributes:
train_set (Dataset): The dataset for training.
val_set (Dataset): The dataset for validation.
train_batch_size (int): Batch size for the training dataset.
val_batch_size (int): Batch size for the validation dataset.
block_size (int): Block size used for collating batch data.
model (SkelcastModule): The model to be trained and validated.
optimizer (torch.optim.Optimizer): Optimizer for model training.
n_epochs (int): Number of epochs to train the model.
device (str): The device ('cpu' or 'cuda') on which to run the model.
checkpoint_dir (str): Directory to save checkpoints.
checkpoint_frequency (int): Frequency (in epochs) at which to save checkpoints.
logger (BaseLogger): Logger for recording training and validation metrics.
log_gradient_info (bool): Flag to determine if gradient information is logged.

Methods:
setup(): Prepares the runner for training and validation.
fit(): Starts the training process from epoch 0.
resume(checkpoint_path): Resumes training from a saved checkpoint.
training_step(train_batch): Executes a single training step.
validation_step(val_batch): Executes a single validation step.
_run_epochs(start_epoch): Runs training and validation from start_epoch through the final epoch.
_run_phase(phase, epoch): Runs a training or validation phase for a single epoch.
_log_epoch_loss(phase, epoch): Logs the loss for a completed epoch.
_restore_state(checkpoint): Restores the state of the model and optimizer from a checkpoint.
_compile_results(): Compiles and returns training and validation results.

Note:
- This class requires a properly formatted SkelcastModule model and corresponding datasets.
- The checkpoint directory must exist before initializing the Runner.
- Logging and checkpointing are optional and can be configured as needed.

Raises:
AssertionError: If the checkpoint directory does not exist.
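
Example:
A minimal usage sketch; the dataset, model, optimizer, and logger names below
are placeholders (some constructor parameters sit in lines collapsed out of
this diff), and the checkpoint path is hypothetical:

runner = Runner(train_set=train_ds, val_set=val_ds,
                train_batch_size=32, val_batch_size=32, block_size=8,
                model=model, optimizer=optimizer, n_epochs=10,
                device='cuda', checkpoint_dir='checkpoints/',  # directory must already exist
                logger=logger, log_gradient_info=True)
runner.setup()
results = runner.fit()  # train from epoch 0
# results = runner.resume('checkpoints/checkpoint_epoch_5.pt')  # or continue a previous run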
"""
def __init__(self,
train_set: Dataset,
val_set: Dataset,
@@ -24,7 +66,8 @@ def __init__(self,
device: str = 'cpu',
checkpoint_dir: str = None,
checkpoint_frequency: int = 1,
logger: BaseLogger = None) -> None:
logger: BaseLogger = None,
log_gradient_info: bool = False) -> None:
self.train_set = train_set
self.val_set = val_set
self.train_batch_size = train_batch_size
@@ -62,6 +105,7 @@ def __init__(self,
self.checkpoint_callback = CheckpointCallback(checkpoint_dir=self.checkpoint_dir,
frequency=self.checkpoint_frequency)
self.logger = logger
self.log_gradient_info = log_gradient_info

def setup(self):
self.model.to(self.device)
@@ -71,48 +115,87 @@ def setup(self):
self.console_callback.training_batches = self._total_train_batches
self.console_callback.validation_batches = self._total_val_batches


def fit(self):
for epoch in range(self.n_epochs):
def _run_epochs(self, start_epoch):
for epoch in range(start_epoch, self.n_epochs):
self.console_callback.on_epoch_start(epoch=epoch)
for train_batch_idx, train_batch in enumerate(self.train_loader):
self.training_step(train_batch=train_batch)
self.console_callback.on_batch_end(batch_idx=train_batch_idx,
loss=self.training_loss_per_step[-1],
phase='train')
epoch_loss = sum(self.training_loss_per_step[epoch * self._total_train_batches:(epoch + 1) * self._total_train_batches]) / self._total_train_batches
self.console_callback.on_epoch_end(epoch=epoch,
epoch_loss=epoch_loss, phase='train')
self.logger.add_scalar(tag='train/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
self.training_loss_history.append(epoch_loss)
for val_batch_idx, val_batch in enumerate(self.val_loader):
self.validation_step(val_batch=val_batch)
self.console_callback.on_batch_end(batch_idx=val_batch_idx,
loss=self.validation_loss_per_step[-1],
phase='val')
epoch_loss = sum(self.validation_loss_per_step[epoch * self._total_val_batches:(epoch + 1) * self._total_val_batches]) / self._total_val_batches
self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='val')
self.validation_loss_history.append(epoch_loss)
self._run_phase('train', epoch)
self._log_epoch_loss('train', epoch)
self._run_phase('val', epoch)
self._log_epoch_loss('val', epoch)
self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self)
self.logger.add_scalar(tag='val/epoch_loss', scalar_value=epoch_loss, global_step=epoch)

def _run_phase(self, phase, epoch):
loader = self.train_loader if phase == 'train' else self.val_loader
step_method = self.training_step if phase == 'train' else self.validation_step
loss_per_step = self.training_loss_per_step if phase == 'train' else self.validation_loss_per_step

for batch_idx, batch in enumerate(loader):
step_method(batch)
self.console_callback.on_batch_end(batch_idx=batch_idx,
loss=loss_per_step[-1],
phase=phase)

def _log_epoch_loss(self, phase, epoch):
loss_per_step = self.training_loss_per_step if phase == 'train' else self.validation_loss_per_step
total_batches = self._total_train_batches if phase == 'train' else self._total_val_batches
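# Per-step losses are kept in one flat list across all epochs, so each epoch's
# mean is taken over its own slice; e.g. with 4 train batches per epoch,
# epoch 1 averages training_loss_per_step[4:8].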
epoch_loss = sum(loss_per_step[epoch * total_batches:(epoch + 1) * total_batches]) / total_batches
self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase=phase)
history = self.training_loss_history if phase == 'train' else self.validation_loss_history
history.append(epoch_loss)
if self.logger is not None:
    self.logger.add_scalar(tag=f'{phase}/epoch_loss', scalar_value=epoch_loss, global_step=epoch)

def resume(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path)
self._restore_state(checkpoint)
start_epoch = checkpoint.get('epoch', 0) + 1
self._run_epochs(start_epoch)
return self._compile_results()

def _restore_state(self, checkpoint):
self.model.load_state_dict(checkpoint.get('model_state_dict'))
self.optimizer.load_state_dict(checkpoint.get('optimizer_state_dict'))
self.training_loss_history = checkpoint.get('training_loss_history')
self.validation_loss_history = checkpoint.get('validation_loss_history')
self.training_loss_per_step = checkpoint.get('training_loss_per_step', [])
self.validation_loss_per_step = checkpoint.get('validation_loss_per_step', [])
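# The expected checkpoint layout is inferred from the keys read above; a file
# written by CheckpointCallback is assumed to look roughly like:
#     {'epoch': 4,
#      'model_state_dict': {...}, 'optimizer_state_dict': {...},
#      'training_loss_history': [...], 'validation_loss_history': [...],
#      'training_loss_per_step': [...], 'validation_loss_per_step': [...]}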

def _compile_results(self):
return {
'training_loss_history': self.training_loss_history,
'training_loss_per_step': self.training_loss_per_step,
'validation_loss_history': self.validation_loss_history,
'validation_loss_per_step': self.validation_loss_per_step
}

def fit(self):
self._run_epochs(start_epoch=0)
return self._compile_results()

def training_step(self, train_batch: NTURGBDSample):
x, y = train_batch.x, train_batch.y
# Cast to torch.float32 and move to the configured device
x, y = x.to(torch.float32), y.to(torch.float32)
x, y = x.to(self.device), y.to(self.device)

self.model.train()
out = self.model.training_step(x, y)
loss = out['loss']
self.optimizer.zero_grad()
loss.backward()
if self.log_gradient_info:
# Get the gradient flow and update norm ratio
self.model.gradient_flow()
self.model.compute_gradient_update_norm(lr=self.optimizer.param_groups[0]['lr'])
grad_hists = self.model.get_gradient_histograms()
# Log the gradient histograms to the logger
if self.logger is not None:
for name, hist in grad_hists.items():
self.logger.add_histogram(tag=f'gradient/hists/{name}_grad_hist', values=hist, global_step=len(self.training_loss_per_step))

# Log the gradient updates to the logger
if self.logger is not None:
for name, ratio in self.model.gradient_update_ratios.items():
self.logger.add_scalar(tag=f'gradient/{name}_grad_update_norm_ratio', scalar_value=ratio, global_step=len(self.training_loss_per_step))

self.optimizer.step()
# Record the step loss
self.training_loss_per_step.append(loss.item())
Expand All @@ -125,70 +208,11 @@ def validation_step(self, val_batch: NTURGBDSample):
# Cast to torch.float32 and move to the configured device
x, y = x.to(torch.float32), y.to(torch.float32)
x, y = x.to(self.device), y.to(self.device)

self.model.eval()
out = self.model.validation_step(x, y)
loss = out['loss']
self.validation_loss_per_step.append(loss.item())
# Log it to the logger
if self.logger is not None:
self.logger.add_scalar(tag='val/step_loss', scalar_value=loss.item(), global_step=len(self.validation_loss_per_step))

def resume(self, checkpoint_path):
"""
Resumes training from a saved checkpoint.

Args:

- checkpoint_path: Path to the checkpoint file.
"""
# Load the checkpoint
checkpoint = torch.load(checkpoint_path)

# Restore the previous epoch's state
self.model.load_state_dict(checkpoint.get('model_state_dict'))
self.optimizer.load_state_dict(checkpoint.get('optimizer_state_dict'))
self.training_loss_history = checkpoint.get('training_loss_history')
self.validation_loss_history = checkpoint.get('validation_loss_history')
self.training_loss_per_step = checkpoint.get('training_loss_per_step', [])
self.validation_loss_per_step = checkpoint.get('validation_loss_per_step', [])

# Set the current epoch to the loaded epoch and start from the next
start_epoch = checkpoint.get('epoch', 0) + 1

# resume the training
for epoch in range(start_epoch, self.n_epochs):
self.console_callback.on_epoch_start(epoch=epoch)
for train_batch_idx, train_batch in enumerate(self.train_loader):
self.training_step(train_batch=train_batch)
self.console_callback.on_batch_end(batch_idx=train_batch_idx,
loss=self.training_loss_per_step[-1],
phase='train')
if self.logger is not None:
self.logger.add_scalar(tag='train/step_loss', scalar_value=self.training_loss_per_step[-1], global_step=len(self.training_loss_per_step))
epoch_loss = sum(self.training_loss_per_step[epoch * self._total_train_batches:(epoch + 1) * self._total_train_batches]) / self._total_train_batches
self.console_callback.on_epoch_end(epoch=epoch,
epoch_loss=epoch_loss, phase='train')
self.training_loss_history.append(epoch_loss)
if self.logger is not None:
self.logger.add_scalar(tag='train/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
for val_batch_idx, val_batch in enumerate(self.val_loader):
self.validation_step(val_batch=val_batch)
self.console_callback.on_batch_end(batch_idx=val_batch_idx,
loss=self.validation_loss_per_step[-1],
phase='val')
if self.logger is not None:
self.logger.add_scalar(tag='val/step_loss', scalar_value=self.validation_loss_per_step[-1], global_step=len(self.validation_loss_per_step))
epoch_loss = sum(self.validation_loss_per_step[epoch * self._total_val_batches:(epoch + 1) * self._total_val_batches]) / self._total_val_batches
self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='val')
self.validation_loss_history.append(epoch_loss)
if self.logger is not None:
self.logger.add_scalar(tag='val/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self)

return {
'training_loss_history': self.training_loss_history,
'training_loss_per_step': self.training_loss_per_step,
'validation_loss_history': self.validation_loss_history,
'validation_loss_per_step': self.validation_loss_per_step
}
