diff --git a/src/skelcast/experiments/runner.py b/src/skelcast/experiments/runner.py
index 93eb689..bd4cded 100644
--- a/src/skelcast/experiments/runner.py
+++ b/src/skelcast/experiments/runner.py
@@ -12,6 +12,48 @@ from skelcast.logger.base import BaseLogger
 
 
 class Runner:
+    """
+    A training and validation runner for models in the Skelcast framework.
+
+    This class handles the setup, training, validation, and checkpointing of SkelcastModule models.
+    It uses separate datasets for training and validation, and includes functionality for batch
+    processing, gradient logging, and checkpoint management.
+
+    Attributes:
+        train_set (Dataset): The dataset for training.
+        val_set (Dataset): The dataset for validation.
+        train_batch_size (int): Batch size for the training dataset.
+        val_batch_size (int): Batch size for the validation dataset.
+        block_size (int): Block size used for collating batch data.
+        model (SkelcastModule): The model to be trained and validated.
+        optimizer (torch.optim.Optimizer): Optimizer for model training.
+        n_epochs (int): Number of epochs to train the model.
+        device (str): The device ('cpu' or 'cuda') on which to run the model.
+        checkpoint_dir (str): Directory in which to save checkpoints.
+        checkpoint_frequency (int): Frequency (in epochs) at which to save checkpoints.
+        logger (BaseLogger): Logger for recording training and validation metrics.
+        log_gradient_info (bool): Whether to log gradient statistics during training.
+
+    Methods:
+        setup(): Prepares the runner for training and validation.
+        fit(): Starts the training process from epoch 0.
+        resume(checkpoint_path): Resumes training from a saved checkpoint.
+        training_step(train_batch): Executes a single training step.
+        validation_step(val_batch): Executes a single validation step.
+        _run_epochs(start_epoch): Runs training and validation from the given start epoch.
+        _run_phase(phase, epoch): Runs a training or validation phase for a single epoch.
+        _log_epoch_loss(phase, epoch): Logs the loss for a completed epoch.
+        _restore_state(checkpoint): Restores the model and optimizer state from a checkpoint.
+        _compile_results(): Compiles and returns the training and validation loss histories.
+
+    Note:
+        - This class requires a properly configured SkelcastModule model and corresponding datasets.
+        - The checkpoint directory must exist before the Runner is initialized.
+        - Logging and checkpointing are optional and can be configured as needed.
+
+    Raises:
+        AssertionError: If the checkpoint directory does not exist.
+    """
     def __init__(self,
                  train_set: Dataset,
                  val_set: Dataset,
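For orientation, here is a minimal usage sketch of the Runner documented above; it is not part of the patch. The `my_*` objects are hypothetical placeholders, and the keyword arguments follow the Attributes list rather than a constructor signature verified outside this diff:

    import torch
    from skelcast.experiments.runner import Runner

    optimizer = torch.optim.AdamW(my_model.parameters(), lr=1e-3)
    runner = Runner(train_set=my_train_set,        # a torch.utils.data.Dataset
                    val_set=my_val_set,
                    train_batch_size=32,
                    val_batch_size=32,
                    block_size=8,
                    model=my_model,                # a SkelcastModule
                    optimizer=optimizer,
                    n_epochs=10,
                    device='cuda',
                    checkpoint_dir='checkpoints',  # must already exist
                    checkpoint_frequency=1,
                    logger=my_logger,              # a BaseLogger
                    log_gradient_info=False)
    runner.setup()
    results = runner.fit()  # dict of loss histories; see _compile_results() below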
+ """ def __init__(self, train_set: Dataset, val_set: Dataset, @@ -24,7 +66,8 @@ def __init__(self, device: str = 'cpu', checkpoint_dir: str = None, checkpoint_frequency: int = 1, - logger: BaseLogger = None) -> None: + logger: BaseLogger = None, + log_gradient_info: bool = False) -> None: self.train_set = train_set self.val_set = val_set self.train_batch_size = train_batch_size @@ -62,6 +105,7 @@ def __init__(self, self.checkpoint_callback = CheckpointCallback(checkpoint_dir=self.checkpoint_dir, frequency=self.checkpoint_frequency) self.logger = logger + self.log_gradient_info = log_gradient_info def setup(self): self.model.to(self.device) @@ -71,31 +115,51 @@ def setup(self): self.console_callback.training_batches = self._total_train_batches self.console_callback.validation_batches = self._total_val_batches - - def fit(self): - for epoch in range(self.n_epochs): + def _run_epochs(self, start_epoch): + for epoch in range(start_epoch, self.n_epochs): self.console_callback.on_epoch_start(epoch=epoch) - for train_batch_idx, train_batch in enumerate(self.train_loader): - self.training_step(train_batch=train_batch) - self.console_callback.on_batch_end(batch_idx=train_batch_idx, - loss=self.training_loss_per_step[-1], - phase='train') - epoch_loss = sum(self.training_loss_per_step[epoch * self._total_train_batches:(epoch + 1) * self._total_train_batches]) / self._total_train_batches - self.console_callback.on_epoch_end(epoch=epoch, - epoch_loss=epoch_loss, phase='train') - self.logger.add_scalar(tag='train/epoch_loss', scalar_value=epoch_loss, global_step=epoch) - self.training_loss_history.append(epoch_loss) - for val_batch_idx, val_batch in enumerate(self.val_loader): - self.validation_step(val_batch=val_batch) - self.console_callback.on_batch_end(batch_idx=val_batch_idx, - loss=self.validation_loss_per_step[-1], - phase='val') - epoch_loss = sum(self.validation_loss_per_step[epoch * self._total_val_batches:(epoch + 1) * self._total_val_batches]) / self._total_val_batches - self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='val') - self.validation_loss_history.append(epoch_loss) + self._run_phase('train', epoch) + self._log_epoch_loss('train', epoch) + self._run_phase('val', epoch) + self._log_epoch_loss('val', epoch) self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self) - self.logger.add_scalar(tag='val/epoch_loss', scalar_value=epoch_loss, global_step=epoch) + def _run_phase(self, phase, epoch): + loader = self.train_loader if phase == 'train' else self.val_loader + step_method = self.training_step if phase == 'train' else self.validation_step + loss_per_step = self.training_loss_per_step if phase == 'train' else self.validation_loss_per_step + + for batch_idx, batch in enumerate(loader): + step_method(batch) + self.console_callback.on_batch_end(batch_idx=batch_idx, + loss=loss_per_step[-1], + phase=phase) + + def _log_epoch_loss(self, phase, epoch): + loss_per_step = self.training_loss_per_step if phase == 'train' else self.validation_loss_per_step + total_batches = self._total_train_batches if phase == 'train' else self._total_val_batches + epoch_loss = sum(loss_per_step[epoch * total_batches:(epoch + 1) * total_batches]) / total_batches + self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase=phase) + history = self.training_loss_history if phase == 'train' else self.validation_loss_history + history.append(epoch_loss) + self.logger.add_scalar(tag=f'{phase}/epoch_loss', scalar_value=epoch_loss, global_step=epoch) + + def 
@@ -103,16 +167,35 @@ def fit(self):
             'validation_loss_history': self.validation_loss_history,
             'validation_loss_per_step': self.validation_loss_per_step
         }
 
+    def fit(self):
+        self._run_epochs(start_epoch=0)
+        return self._compile_results()
+
     def training_step(self, train_batch: NTURGBDSample):
         x, y = train_batch.x, train_batch.y
         # Cast them to a torch float32 and move them to the gpu
         x, y = x.to(torch.float32), y.to(torch.float32)
         x, y = x.to(self.device), y.to(self.device)
-
+        self.model.train()
         out = self.model.training_step(x, y)
         loss = out['loss']
         self.optimizer.zero_grad()
         loss.backward()
+        if self.log_gradient_info:
+            # Get the gradient flow and update norm ratio
+            self.model.gradient_flow()
+            self.model.compute_gradient_update_norm(lr=self.optimizer.param_groups[0]['lr'])
+            grad_hists = self.model.get_gradient_histograms()
+            # Log the gradient histograms to the logger
+            if self.logger is not None:
+                for name, hist in grad_hists.items():
+                    self.logger.add_histogram(tag=f'gradient/hists/{name}_grad_hist', values=hist, global_step=len(self.training_loss_per_step))
+
+            # Log the gradient updates to the logger
+            if self.logger is not None:
+                for name, ratio in self.model.gradient_update_ratios.items():
+                    self.logger.add_scalar(tag=f'gradient/{name}_grad_update_norm_ratio', scalar_value=ratio, global_step=len(self.training_loss_per_step))
+
         self.optimizer.step()
         # Print the loss
         self.training_loss_per_step.append(loss.item())
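The `SkelcastModule` helpers invoked above (`gradient_flow`, `compute_gradient_update_norm`, `get_gradient_histograms`, `gradient_update_ratios`) are defined outside this diff. Purely as an illustration of the statistic being logged, one common definition of a per-parameter update-to-weight norm ratio is sketched below; this is an assumption, not the module's actual implementation:

    import torch
    import torch.nn as nn

    def update_to_weight_ratios(model: nn.Module, lr: float) -> dict:
        # Illustrative stand-in for SkelcastModule.compute_gradient_update_norm.
        ratios = {}
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            # Norm of the (SGD-style) step about to be taken, relative to the
            # weight norm; values near 1e-3 are a common rule-of-thumb target.
            ratios[name] = (lr * param.grad.norm() / (param.norm() + 1e-12)).item()
        return ratios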
- """ - # Load the checkpoint - checkpoint = torch.load(checkpoint_path) - - # Restore the previous' epoch's state - self.model.load_state_dict(checkpoint.get('model_state_dict')) - self.optimizer.load_state_dict(checkpoint.get('optimizer_state_dict')) - self.training_loss_history = checkpoint.get('training_loss_history') - self.validation_loss_history = checkpoint.get('validation_loss_history') - self.training_loss_per_step = checkpoint.get('training_loss_per_step', []) - self.validation_loss_per_step = checkpoint.get('validation_loss_per_step', []) - - # Set the current epoch to the loaded epoch and start from the next - start_epoch = checkpoint.get('epoch', 0) + 1 - - # resume the training - for epoch in range(start_epoch, self.n_epochs): - self.console_callback.on_epoch_start(epoch=epoch) - for train_batch_idx, train_batch in enumerate(self.train_loader): - self.training_step(train_batch=train_batch) - self.console_callback.on_batch_end(batch_idx=train_batch_idx, - loss=self.training_loss_per_step[-1], - phase='train') - if self.logger is not None: - self.logger.add_scalar(tag='train/step_loss', scalar_value=self.training_loss_per_step[-1], global_step=len(self.training_loss_per_step)) - epoch_loss = sum(self.training_loss_per_step[epoch * self._total_train_batches:(epoch + 1) * self._total_train_batches]) / self._total_train_batches - self.console_callback.on_epoch_end(epoch=epoch, - epoch_loss=epoch_loss, phase='train') - self.training_loss_history.append(epoch_loss) - if self.logger is not None: - self.logger.add_scalar(tag='train/epoch_loss', scalar_value=epoch_loss, global_step=epoch) - for val_batch_idx, val_batch in enumerate(self.val_loader): - self.validation_step(val_batch=val_batch) - self.console_callback.on_batch_end(batch_idx=val_batch_idx, - loss=self.validation_loss_per_step[-1], - phase='val') - if self.logger is not None: - self.logger.add_scalar(tag='val/step_loss', scalar_value=self.validation_loss_per_step[-1], global_step=len(self.validation_loss_per_step)) - epoch_loss = sum(self.validation_loss_per_step[epoch * self._total_val_batches:(epoch + 1) * self._total_val_batches]) / self._total_val_batches - self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='val') - self.validation_loss_history.append(epoch_loss) - if self.logger is not None: - self.logger.add_scalar(tag='val/epoch_loss', scalar_value=epoch_loss, global_step=epoch) - self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self) - - return { - 'training_loss_history': self.training_loss_history, - 'training_loss_per_step': self.training_loss_per_step, - 'validation_loss_history': self.validation_loss_history, - 'validation_loss_per_step': self.validation_loss_per_step - } \ No newline at end of file