diff --git a/.gitignore b/.gitignore
index 3acf2b8..51b0779 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 data/skeletons/*
 .vscode/
+runs/*
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/src/skelcast/data/dataset.py b/src/skelcast/data/dataset.py
index af506bd..878340e 100644
--- a/src/skelcast/data/dataset.py
+++ b/src/skelcast/data/dataset.py
@@ -132,8 +132,9 @@ def __call__(self, batch) -> NTURGBDSample:
         batch_y = []
         for sample, _ in batch:
             x, y = self.get_windows(sample)
-            batch_x.append(x)
-            batch_y.append(y)
+            chunk_size, context_len, n_bodies, n_joints, n_dims = x.shape
+            batch_x.append(x.view(chunk_size, context_len, n_bodies * n_joints * n_dims))
+            batch_y.append(y.view(chunk_size, context_len, n_bodies * n_joints * n_dims))
         # Pad the sequences to the maximum sequence length in the batch
         batch_x = torch.nn.utils.rnn.pad_sequence(batch_x, batch_first=True)
         batch_y = torch.nn.utils.rnn.pad_sequence(batch_y, batch_first=True)
diff --git a/src/skelcast/data/prepare_data.py b/src/skelcast/data/prepare_data.py
index 7aae228..69867fe 100644
--- a/src/skelcast/data/prepare_data.py
+++ b/src/skelcast/data/prepare_data.py
@@ -5,7 +5,6 @@
 def get_missing_files(missing_files_dir):
     missing_skel_files = []
     for fname in os.listdir(missing_files_dir):
-        print(fname)
         with open(osp.join(missing_files_dir, fname)) as f:
             for idx, line in enumerate(f):
                 if idx > 2:
@@ -31,7 +30,6 @@ def filter_missing(skeleton_files: list, missing_skeleton_names: list):
         for f in skeleton_files
         if os.path.splitext(os.path.basename(f))[0] not in missing_skeleton_names
     ]
-    print(f"Skeleton files after filtering: {len(filtered_skeleton_files)} files left.")
     return filtered_skeleton_files
diff --git a/src/skelcast/experiments/runner.py b/src/skelcast/experiments/runner.py
index 5826ee4..93eb689 100644
--- a/src/skelcast/experiments/runner.py
+++ b/src/skelcast/experiments/runner.py
@@ -9,6 +9,7 @@
 from skelcast.data.dataset import NTURGBDCollateFn, NTURGBDSample
 from skelcast.callbacks.console import ConsoleCallback
 from skelcast.callbacks.checkpoint import CheckpointCallback
+from skelcast.logger.base import BaseLogger
 
 class Runner:
     def __init__(self,
@@ -22,13 +23,14 @@
                  n_epochs: int = 10,
                  device: str = 'cpu',
                  checkpoint_dir: str = None,
-                 checkpoint_frequency: int = 1) -> None:
+                 checkpoint_frequency: int = 1,
+                 logger: BaseLogger = None) -> None:
         self.train_set = train_set
         self.val_set = val_set
         self.train_batch_size = train_batch_size
         self.val_batch_size = val_batch_size
         self.block_size = block_size
-        self._collate_fn = NTURGBDCollateFn(block_size=self.block_size)
+        self._collate_fn = NTURGBDCollateFn(block_size=self.block_size, is_packed=True)
         self.train_loader = DataLoader(dataset=self.train_set, batch_size=self.train_batch_size, shuffle=True, collate_fn=self._collate_fn)
         self.val_loader = DataLoader(dataset=self.val_set, batch_size=self.val_batch_size, shuffle=False, collate_fn=self._collate_fn)
         self.model = model
@@ -59,6 +61,7 @@ def __init__(self,
         assert os.path.exists(self.checkpoint_dir), f'The designated checkpoint directory `{self.checkpoint_dir}` does not exist.'
         self.checkpoint_callback = CheckpointCallback(checkpoint_dir=self.checkpoint_dir, frequency=self.checkpoint_frequency)
+        self.logger = logger
 
     def setup(self):
         self.model.to(self.device)
@@ -80,6 +83,7 @@
             epoch_loss = sum(self.training_loss_per_step[epoch * self._total_train_batches:(epoch + 1) * self._total_train_batches]) / self._total_train_batches
             self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='train')
+            self.logger.add_scalar(tag='train/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
             self.training_loss_history.append(epoch_loss)
             for val_batch_idx, val_batch in enumerate(self.val_loader):
                 self.validation_step(val_batch=val_batch)
@@ -90,6 +94,7 @@
             self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='val')
             self.validation_loss_history.append(epoch_loss)
             self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self)
+            self.logger.add_scalar(tag='val/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
 
         return {
             'training_loss_history': self.training_loss_history,
@@ -111,6 +116,9 @@ def training_step(self, train_batch: NTURGBDSample):
         self.optimizer.step()
         # Print the loss
         self.training_loss_per_step.append(loss.item())
+        # Log it to the logger
+        if self.logger is not None:
+            self.logger.add_scalar(tag='train/step_loss', scalar_value=loss.item(), global_step=len(self.training_loss_per_step))
 
     def validation_step(self, val_batch: NTURGBDSample):
         x, y = val_batch.x, val_batch.y
@@ -121,6 +129,9 @@
         out = self.model.validation_step(x, y)
         loss = out['loss']
         self.validation_loss_per_step.append(loss.item())
+        # Log it to the logger
+        if self.logger is not None:
+            self.logger.add_scalar(tag='val/step_loss', scalar_value=loss.item(), global_step=len(self.validation_loss_per_step))
 
     def resume(self, checkpoint_path):
         """
@@ -152,18 +163,26 @@
                 self.console_callback.on_batch_end(batch_idx=train_batch_idx, loss=self.training_loss_per_step[-1], phase='train')
+                if self.logger is not None:
+                    self.logger.add_scalar(tag='train/step_loss', scalar_value=self.training_loss_per_step[-1], global_step=len(self.training_loss_per_step))
             epoch_loss = sum(self.training_loss_per_step[epoch * self._total_train_batches:(epoch + 1) * self._total_train_batches]) / self._total_train_batches
             self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='train')
             self.training_loss_history.append(epoch_loss)
+            if self.logger is not None:
+                self.logger.add_scalar(tag='train/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
             for val_batch_idx, val_batch in enumerate(self.val_loader):
                 self.validation_step(val_batch=val_batch)
                 self.console_callback.on_batch_end(batch_idx=val_batch_idx, loss=self.validation_loss_per_step[-1], phase='val')
+                if self.logger is not None:
+                    self.logger.add_scalar(tag='val/step_loss', scalar_value=self.validation_loss_per_step[-1], global_step=len(self.validation_loss_per_step))
             epoch_loss = sum(self.validation_loss_per_step[epoch * self._total_val_batches:(epoch + 1) * self._total_val_batches]) / self._total_val_batches
             self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='val')
             self.validation_loss_history.append(epoch_loss)
+            if self.logger is not None:
+                self.logger.add_scalar(tag='val/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
             self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self)
 
         return {
diff --git a/src/skelcast/logger/__init__.py b/src/skelcast/logger/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/skelcast/logger/base.py b/src/skelcast/logger/base.py
new file mode 100644
index 0000000..d11ccff
--- /dev/null
+++ b/src/skelcast/logger/base.py
@@ -0,0 +1,23 @@
+
+from abc import ABC, abstractmethod
+
+class BaseLogger(ABC):
+    @abstractmethod
+    def add_scalar(self):
+        pass
+
+    @abstractmethod
+    def add_scalars(self):
+        pass
+
+    @abstractmethod
+    def add_histogram(self):
+        pass
+
+    @abstractmethod
+    def add_image(self):
+        pass
+
+    @abstractmethod
+    def close(self):
+        pass
diff --git a/src/skelcast/logger/tensorboard_logger.py b/src/skelcast/logger/tensorboard_logger.py
new file mode 100644
index 0000000..1ba2908
--- /dev/null
+++ b/src/skelcast/logger/tensorboard_logger.py
@@ -0,0 +1,25 @@
+from torch.utils.tensorboard import SummaryWriter
+from skelcast.logger.base import BaseLogger
+
+
+class TensorboardLogger(BaseLogger):
+    def __init__(self, log_dir):
+        super().__init__()
+        self.writer = SummaryWriter(log_dir)
+
+    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
+        self.writer.add_scalar(tag, scalar_value, global_step, walltime)
+
+    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
+        self.writer.add_scalars(main_tag, tag_scalar_dict, global_step, walltime)
+
+    def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None):
+        self.writer.add_histogram(tag, values, global_step, bins, walltime, max_bins)
+
+    def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
+        self.writer.add_image(tag, img_tensor, global_step, walltime, dataformats)
+
+    # Implement other methods from SummaryWriter as needed
+
+    def close(self):
+        self.writer.close()
diff --git a/src/skelcast/models/module.py b/src/skelcast/models/module.py
index 840e2f9..622e22a 100644
--- a/src/skelcast/models/module.py
+++ b/src/skelcast/models/module.py
@@ -5,6 +5,8 @@
 class SkelcastModule(nn.Module, metaclass=abc.ABCMeta):
     def __init__(self) -> None:
         super(SkelcastModule, self).__init__()
+        self.gradients = dict()
+        self.gradient_update_ratios = dict()
 
     @abc.abstractmethod
     def predict(self, *args, **kwargs):
@@ -26,3 +28,31 @@ def validation_step(self, *args, **kwargs):
         Implements a validation step of a module
         """
         pass
+
+    def gradient_flow(self):
+        """
+        Implements the gradient flow step of a module
+        """
+        for name, param in self.named_parameters():
+            if param.requires_grad and param.grad is not None:
+                self.gradients[name] = param.grad.clone().detach().cpu().numpy()
+
+    def compute_gradient_update_norm(self, lr: float):
+        """
+        Computes the ratio of the parameter update to the parameter norm and stores it to the gradient_update_ratios
+        dictionary. The gradient update is approximated as the vanilla gradient descent update.
+
+        Args:
+        - lr (float): The optimizer's learning rate
+        """
+        for name, param in self.named_parameters():
+            if param.requires_grad and param.grad is not None:
+                self.gradient_update_ratios[name] = (lr * param.grad.norm() / param.norm()).detach().cpu().numpy()
+
+    def get_gradient_histograms(self):
+        """
+        Returns the flat gradients of the module's parameters from the gradients dictionary.
+ """ + return {name: param.grad.clone().view(-1).detach().cpu().numpy() for name, param in self.named_parameters() if + param.requires_grad and param.grad is not None} + \ No newline at end of file diff --git a/src/skelcast/models/rnn/lstm.py b/src/skelcast/models/rnn/lstm.py index 9ae408e..bd190ce 100644 --- a/src/skelcast/models/rnn/lstm.py +++ b/src/skelcast/models/rnn/lstm.py @@ -1,5 +1,8 @@ import torch import torch.nn as nn +from torch.nn.utils.rnn import PackedSequence + +from typing import Union from skelcast.models import SkelcastModule @@ -11,7 +14,8 @@ def __init__(self, batch_first: bool = True, num_bodies: int = 1, n_joints: int = 25, - n_dims: int = 3) -> None: + n_dims: int = 3, + reduction: str = 'mean') -> None: super().__init__() self.num_bodies = num_bodies self.n_joints = n_joints @@ -22,24 +26,26 @@ def __init__(self, batch_first=batch_first) self.linear = nn.Linear(in_features=hidden_size, out_features=input_size) self.relu = nn.ReLU() - self.criterion = nn.MSELoss(reduction='sum') + self.criterion = nn.MSELoss(reduction=reduction) - def forward(self, x: torch.Tensor, - y: torch.Tensor = None): - assert x.ndim == 5, f'`x` must be a 5-dimensional tensor. Found {x.ndim} dimension(s).' - batch_size, context_size, num_bodies, n_joints, n_dims = x.shape - assert num_bodies == self.num_bodies, f'The number of bodies in the position 2 of the tensor is {num_bodies}, but it should be {self.num_bodies}' - assert n_joints == self.n_joints, f'The number of bodies in the position 3 of the tensor is {n_joints}, but it should be {self.n_joints}' - assert n_dims == self.n_dims, f'The number of bodies in the position 3 of the tensor is {n_dims}, but it should be {self.n_dims}' + def forward(self, x: Union[torch.Tensor, PackedSequence], + y: Union[torch.Tensor, PackedSequence] = None): + if isinstance(x, torch.Tensor): + assert x.ndim == 5, f'`x` must be a 5-dimensional tensor. Found {x.ndim} dimension(s).' + batch_size, context_size, n_dim = x.shape + assert n_dim == self.num_bodies * self.n_joints * self.n_dims, f'The number of bodies in the position 2 of the tensor is {n_dim}, but it should be {self.num_bodies * self.n_joints * self.n_dims}' + else: + assert x.data.ndim == 3, f'`x` must be a 5-dimensional tensor. Found {x.ndim} dimension(s).' + batch_size, context_size, n_dim = x.data.shape + assert n_dim == self.num_bodies * self.n_joints * self.n_dims, f'The number of bodies in the position 2 of the tensor is {n_dim}, but it should be {self.num_bodies * self.n_joints * self.n_dims}' - x = x.view(batch_size, context_size, num_bodies * n_joints * n_dims) - x = self.linear_transform(x) + x = self.linear_transform(x if isinstance(x, torch.Tensor) else x.data) out, _ = self.lstm(x) - out = self.linear(out) - out = self.relu(out) + out = self.linear(out if isinstance(out, torch.Tensor) else out.data) + out = self.relu(out if isinstance(out, torch.Tensor) else out.data) if y is not None: - y = y.view(batch_size, context_size, num_bodies * n_joints * n_dims) - loss = self.criterion(out, y) + loss = self.criterion(out if isinstance(out, torch.Tensor) else out.data, + y if isinstance(y, torch.Tensor) else y.data) return out, loss return out