Commit
Merge pull request #83 from kaseris/fix/runner
Finally fix the case of NoneType masks.
kaseris committed Feb 9, 2024
2 parents c82a257 + 85c6e88 commit 33665e1
Showing 1 changed file with 163 additions and 80 deletions.
243 changes: 163 additions & 80 deletions src/skelcast/experiments/runner.py
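The substantive change in this diff is that the training and validation steps now guard the mask before casting it and moving it to the device, since a collate function may produce batches whose mask is None. Below is a minimal sketch of that pattern, with a hypothetical helper name and plain tensor arguments standing in for the NTURGBDSample fields used in the diff:

```python
import torch

def move_batch_to_device(x, y, mask, device):
    """Hypothetical helper illustrating the None-mask guard introduced by this commit."""
    # x and y are always tensors: cast to float32 and move to the target device.
    x = x.to(torch.float32).to(device)
    y = y.to(torch.float32).to(device)
    # The mask may be None (e.g. a collate_fn that does not build one),
    # so only cast/move it when it is actually a tensor.
    if mask is not None:
        mask = mask.to(torch.float32).to(device)
    return x, y, mask
```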
@@ -11,12 +11,13 @@
from skelcast.callbacks.checkpoint import CheckpointCallback
from skelcast.logger.base import BaseLogger


class Runner:
"""
A training and validation runner for models in the Skelcast framework.
This class handles the setup, training, validation, and checkpointing of SkelcastModule models.
It uses datasets for training and validation, and includes functionality for batch processing,
This class handles the setup, training, validation, and checkpointing of SkelcastModule models.
It uses datasets for training and validation, and includes functionality for batch processing,
gradient logging, and checkpoint management.
@@ -59,30 +60,47 @@ class Runner:
---
`AssertionError`: If the checkpoint directory does not exist.
"""
def __init__(self,
train_set: Dataset,
val_set: Dataset,
train_batch_size: int,
val_batch_size: int,
block_size: int,
model: SkelcastModule,
optimizer: torch.optim.Optimizer = None,
lr: float = 1e-4,
n_epochs: int = 10,
device: str = 'cpu',
checkpoint_dir: str = None,
checkpoint_frequency: int = 1,
logger: BaseLogger = None,
log_gradient_info: bool = False,
collate_fn = None) -> None:

def __init__(
self,
train_set: Dataset,
val_set: Dataset,
train_batch_size: int,
val_batch_size: int,
block_size: int,
model: SkelcastModule,
optimizer: torch.optim.Optimizer = None,
lr: float = 1e-4,
n_epochs: int = 10,
device: str = "cpu",
checkpoint_dir: str = None,
checkpoint_frequency: int = 1,
logger: BaseLogger = None,
log_gradient_info: bool = False,
collate_fn=None,
) -> None:
self.train_set = train_set
self.val_set = val_set
self.train_batch_size = train_batch_size
self.val_batch_size = val_batch_size
self.block_size = block_size
self._collate_fn = collate_fn if collate_fn is not None else NTURGBDCollateFn(block_size=self.block_size)
self.train_loader = DataLoader(dataset=self.train_set, batch_size=self.train_batch_size, shuffle=True, collate_fn=self._collate_fn)
self.val_loader = DataLoader(dataset=self.val_set, batch_size=self.val_batch_size, shuffle=False, collate_fn=self._collate_fn)
self._collate_fn = (
collate_fn
if collate_fn is not None
else NTURGBDCollateFn(block_size=self.block_size)
)
self.train_loader = DataLoader(
dataset=self.train_set,
batch_size=self.train_batch_size,
shuffle=True,
collate_fn=self._collate_fn,
)
self.val_loader = DataLoader(
dataset=self.val_set,
batch_size=self.val_batch_size,
shuffle=False,
collate_fn=self._collate_fn,
)
self.model = model
self.lr = lr
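For orientation, here is a hypothetical way to construct and run the Runner with the signature shown above; the dataset, model, and logger objects are placeholders assumed to be built elsewhere, and the call order simply follows the public methods that appear in this diff (setup, fit, resume):

```python
# Hypothetical usage sketch; train_set, val_set, model and logger are
# placeholders assumed to be constructed elsewhere in the project.
runner = Runner(
    train_set=train_set,
    val_set=val_set,
    train_batch_size=32,
    val_batch_size=32,
    block_size=8,
    model=model,
    lr=1e-4,
    n_epochs=10,
    device="cuda",
    checkpoint_dir="checkpoints",  # must already exist; __init__ asserts this
    checkpoint_frequency=1,
    logger=logger,
    log_gradient_info=True,
)
runner.setup()
runner.fit()
# runner.resume("path/to/checkpoint.pt") would instead restore state and continue training.
```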

@@ -98,20 +116,23 @@ def __init__(self,

self.n_epochs = n_epochs

self._status_message = ''
self._status_message = ""

if device != 'cpu':
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device != "cpu":
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device('cpu')
self.device = torch.device("cpu")

self.console_callback = ConsoleCallback()

self.checkpoint_dir = checkpoint_dir
self.checkpoint_frequency = checkpoint_frequency
assert os.path.exists(self.checkpoint_dir), f'The designated checkpoint directory `{self.checkpoint_dir}` does not exist.'
self.checkpoint_callback = CheckpointCallback(checkpoint_dir=self.checkpoint_dir,
frequency=self.checkpoint_frequency)
assert os.path.exists(
self.checkpoint_dir
), f"The designated checkpoint directory `{self.checkpoint_dir}` does not exist."
self.checkpoint_callback = CheckpointCallback(
checkpoint_dir=self.checkpoint_dir, frequency=self.checkpoint_frequency
)
self.logger = logger
self.log_gradient_info = log_gradient_info

@@ -126,53 +147,74 @@ def setup(self):
def _run_epochs(self, start_epoch):
for epoch in range(start_epoch, self.n_epochs):
self.console_callback.on_epoch_start(epoch=epoch)
self._run_phase('train', epoch)
self._log_epoch_loss('train', epoch)
self._run_phase('val', epoch)
self._log_epoch_loss('val', epoch)
self._run_phase("train", epoch)
self._log_epoch_loss("train", epoch)
self._run_phase("val", epoch)
self._log_epoch_loss("val", epoch)
self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self)

def _run_phase(self, phase, epoch):
loader = self.train_loader if phase == 'train' else self.val_loader
step_method = self.training_step if phase == 'train' else self.validation_step
loss_per_step = self.training_loss_per_step if phase == 'train' else self.validation_loss_per_step
loader = self.train_loader if phase == "train" else self.val_loader
step_method = self.training_step if phase == "train" else self.validation_step
loss_per_step = (
self.training_loss_per_step
if phase == "train"
else self.validation_loss_per_step
)

for batch_idx, batch in enumerate(loader):
step_method(batch)
self.console_callback.on_batch_end(batch_idx=batch_idx,
loss=loss_per_step[-1],
phase=phase)
self.console_callback.on_batch_end(
batch_idx=batch_idx, loss=loss_per_step[-1], phase=phase
)

def _log_epoch_loss(self, phase, epoch):
loss_per_step = self.training_loss_per_step if phase == 'train' else self.validation_loss_per_step
total_batches = self._total_train_batches if phase == 'train' else self._total_val_batches
epoch_loss = sum(loss_per_step[epoch * total_batches:(epoch + 1) * total_batches]) / total_batches
self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase=phase)
history = self.training_loss_history if phase == 'train' else self.validation_loss_history
loss_per_step = (
self.training_loss_per_step
if phase == "train"
else self.validation_loss_per_step
)
total_batches = (
self._total_train_batches if phase == "train" else self._total_val_batches
)
epoch_loss = (
sum(loss_per_step[epoch * total_batches : (epoch + 1) * total_batches])
/ total_batches
)
self.console_callback.on_epoch_end(
epoch=epoch, epoch_loss=epoch_loss, phase=phase
)
history = (
self.training_loss_history
if phase == "train"
else self.validation_loss_history
)
history.append(epoch_loss)
self.logger.add_scalar(tag=f'{phase}/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
self.logger.add_scalar(
tag=f"{phase}/epoch_loss", scalar_value=epoch_loss, global_step=epoch
)

def resume(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path)
self._restore_state(checkpoint)
start_epoch = checkpoint.get('epoch', 0) + 1
start_epoch = checkpoint.get("epoch", 0) + 1
self._run_epochs(start_epoch)
return self._compile_results()

def _restore_state(self, checkpoint):
self.model.load_state_dict(checkpoint.get('model_state_dict'))
self.optimizer.load_state_dict(checkpoint.get('optimizer_state_dict'))
self.training_loss_history = checkpoint.get('training_loss_history')
self.validation_loss_history = checkpoint.get('validation_loss_history')
self.training_loss_per_step = checkpoint.get('training_loss_per_step', [])
self.validation_loss_per_step = checkpoint.get('validation_loss_per_step', [])
self.model.load_state_dict(checkpoint.get("model_state_dict"))
self.optimizer.load_state_dict(checkpoint.get("optimizer_state_dict"))
self.training_loss_history = checkpoint.get("training_loss_history")
self.validation_loss_history = checkpoint.get("validation_loss_history")
self.training_loss_per_step = checkpoint.get("training_loss_per_step", [])
self.validation_loss_per_step = checkpoint.get("validation_loss_per_step", [])

def _compile_results(self):
return {
'training_loss_history': self.training_loss_history,
'training_loss_per_step': self.training_loss_per_step,
'validation_loss_history': self.validation_loss_history,
'validation_loss_per_step': self.validation_loss_per_step
"training_loss_history": self.training_loss_history,
"training_loss_per_step": self.training_loss_per_step,
"validation_loss_history": self.validation_loss_history,
"validation_loss_per_step": self.validation_loss_per_step,
}

def fit(self):
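As a reading aid, the epoch loss computed in `_log_epoch_loss` above is just the mean of the per-step losses belonging to the given epoch, obtained by slicing the flat per-step list. A small numeric illustration of that slice, with made-up loss values:

```python
# Toy illustration of the slice used in _log_epoch_loss (values are made up).
loss_per_step = [1.0, 0.75, 0.5, 0.5, 0.25, 0.75]  # 2 epochs x 3 batches per epoch
total_batches = 3
epoch = 1
epoch_loss = sum(loss_per_step[epoch * total_batches:(epoch + 1) * total_batches]) / total_batches
print(epoch_loss)  # (0.5 + 0.25 + 0.75) / 3 = 0.5
```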
@@ -182,57 +224,98 @@ def fit(self):
def training_step(self, train_batch: NTURGBDSample):
x, y, mask = train_batch.x, train_batch.y, train_batch.mask
# Cast them to a torch float32 and move them to the gpu
# TODO: Handle the mask None case
x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
x, y = x.to(torch.float32), y.to(torch.float32)
x, y = x.to(self.device), y.to(self.device)
if mask is not None:
mask = mask.to(torch.float32)
mask = mask.to(self.device)

self.model.train()
out = self.model.training_step(x=x, y=y, mask=mask) # TODO: Make the other models accept a mask as well
loss = out['loss']
outputs = out['out']
out = self.model.training_step(
x=x, y=y, mask=mask
) # TODO: Make the other models accept a mask as well
loss = out["loss"]
outputs = out["out"]
# Calculate the saturation of the tanh output
saturated = (outputs.abs() > 0.95)
saturation_percentage = saturated.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100
saturated = outputs.abs() > 0.95
saturation_percentage = (
saturated.sum(dim=(1, 2)).float()
/ (outputs.size(1) * outputs.size(2))
* 100
)
# Calculate the dead neurons
dead_neurons = (outputs.abs() < 0.05)
dead_neurons_percentage = dead_neurons.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100
dead_neurons = outputs.abs() < 0.05
dead_neurons_percentage = (
dead_neurons.sum(dim=(1, 2)).float()
/ (outputs.size(1) * outputs.size(2))
* 100
)
self.optimizer.zero_grad()
loss.backward()
if self.log_gradient_info:
# Get the gradient flow and update norm ratio
# Get the gradient flow and update norm ratio
self.model.gradient_flow()
self.model.compute_gradient_update_norm(lr=self.optimizer.param_groups[0]['lr'])
self.model.compute_gradient_update_norm(
lr=self.optimizer.param_groups[0]["lr"]
)
grad_hists = self.model.get_gradient_histograms()
# Log the gradient histograms to the logger
if self.logger is not None:
for name, hist in grad_hists.items():
self.logger.add_histogram(tag=f'gradient/hists/{name}_grad_hist', values=hist, global_step=len(self.training_loss_per_step))
self.logger.add_histogram(
tag=f"gradient/hists/{name}_grad_hist",
values=hist,
global_step=len(self.training_loss_per_step),
)

# Log the gradient updates to the logger
if self.logger is not None:
for name, ratio in self.model.gradient_update_ratios.items():
self.logger.add_scalar(tag=f'gradient/{name}_grad_update_norm_ratio', scalar_value=ratio, global_step=len(self.training_loss_per_step))
self.logger.add_scalar(
tag=f"gradient/{name}_grad_update_norm_ratio",
scalar_value=ratio,
global_step=len(self.training_loss_per_step),
)

if self.logger is not None:
self.logger.add_scalar(tag='train/saturation', scalar_value=saturation_percentage.mean().item(), global_step=len(self.training_loss_per_step))
self.logger.add_scalar(tag='train/dead_neurons', scalar_value=dead_neurons_percentage.mean().item(), global_step=len(self.training_loss_per_step))
self.logger.add_scalar(
tag="train/saturation",
scalar_value=saturation_percentage.mean().item(),
global_step=len(self.training_loss_per_step),
)
self.logger.add_scalar(
tag="train/dead_neurons",
scalar_value=dead_neurons_percentage.mean().item(),
global_step=len(self.training_loss_per_step),
)

self.optimizer.step()
# Print the loss
self.training_loss_per_step.append(loss.item())
# Log it to the logger
if self.logger is not None:
self.logger.add_scalar(tag='train/step_loss', scalar_value=loss.item(), global_step=len(self.training_loss_per_step))
self.logger.add_scalar(
tag="train/step_loss",
scalar_value=loss.item(),
global_step=len(self.training_loss_per_step),
)

def validation_step(self, val_batch: NTURGBDSample):
x, y, mask = val_batch.x, val_batch.y, val_batch.mask
# Cast them to a torch float32 and move them to the gpu
x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
x, y = x.to(torch.float32), y.to(torch.float32)
x, y = x.to(self.device), y.to(self.device)
if mask is not None:
mask = mask.to(torch.float32)
mask = mask.to(self.device)
self.model.eval()
out = self.model.validation_step(x=x, y=y, mask=mask)
loss = out['loss']
loss = out["loss"]
self.validation_loss_per_step.append(loss.item())
# Log it to the logger
if self.logger is not None:
self.logger.add_scalar(tag='val/step_loss', scalar_value=loss.item(), global_step=len(self.validation_loss_per_step))

self.logger.add_scalar(
tag="val/step_loss",
scalar_value=loss.item(),
global_step=len(self.validation_loss_per_step),
)
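For reference, the saturation and dead-neuron percentages logged in `training_step` reduce the (batch, time, feature) output tensor to one percentage per sample. A self-contained sketch of the same computation on dummy data (the tensor shape is illustrative, not taken from the dataset):

```python
import torch

# Dummy tanh outputs with an illustrative (batch, seq_len, features) shape.
outputs = torch.tanh(torch.randn(4, 10, 75))
total = outputs.size(1) * outputs.size(2)

saturated = outputs.abs() > 0.95          # near the +/-1 asymptotes of tanh
saturation_pct = saturated.sum(dim=(1, 2)).float() / total * 100

dead = outputs.abs() < 0.05               # near-zero activations
dead_pct = dead.sum(dim=(1, 2)).float() / total * 100

print(saturation_pct.mean().item(), dead_pct.mean().item())
```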
