
Merge to main #27

Merged · 10 commits · Nov 29, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
data/skeletons/*
.vscode/
runs/*

# Byte-compiled / optimized / DLL files
__pycache__/
5 changes: 3 additions & 2 deletions src/skelcast/data/dataset.py
@@ -132,8 +132,9 @@ def __call__(self, batch) -> NTURGBDSample:
batch_y = []
for sample, _ in batch:
x, y = self.get_windows(sample)
batch_x.append(x)
batch_y.append(y)
chunk_size, context_len, n_bodies, n_joints, n_dims = x.shape
batch_x.append(x.view(chunk_size, context_len, n_bodies * n_joints * n_dims))
batch_y.append(y.view(chunk_size, context_len, n_bodies * n_joints * n_dims))
# Pad the sequences to the maximum sequence length in the batch
batch_x = torch.nn.utils.rnn.pad_sequence(batch_x, batch_first=True)
batch_y = torch.nn.utils.rnn.pad_sequence(batch_y, batch_first=True)
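
Note: flattening each window to 3-D before padding matters because `pad_sequence` pads only along the first dimension and requires all trailing dimensions to agree. A minimal shape sketch with made-up sizes:

import torch

# Hypothetical sizes: 4 windows, 8-frame context, 2 bodies, 25 joints, 3 coordinates.
x = torch.randn(4, 8, 2, 25, 3)
chunk_size, context_len, n_bodies, n_joints, n_dims = x.shape
flat = x.view(chunk_size, context_len, n_bodies * n_joints * n_dims)
print(flat.shape)  # torch.Size([4, 8, 150])
# A batch of two samples with 4 and 2 windows pads along the window dimension:
padded = torch.nn.utils.rnn.pad_sequence([flat, flat[:2]], batch_first=True)
print(padded.shape)  # torch.Size([2, 4, 8, 150])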
2 changes: 0 additions & 2 deletions src/skelcast/data/prepare_data.py
@@ -5,7 +5,6 @@
def get_missing_files(missing_files_dir):
missing_skel_files = []
for fname in os.listdir(missing_files_dir):
print(fname)
with open(osp.join(missing_files_dir, fname)) as f:
for idx, line in enumerate(f):
if idx > 2:
@@ -31,7 +30,6 @@ def filter_missing(skeleton_files: list, missing_skeleton_names: list):
for f in skeleton_files
if os.path.splitext(os.path.basename(f))[0] not in missing_skeleton_names
]
print(f"Skeleton files after filtering: {len(filtered_skeleton_files)} files left.")
return filtered_skeleton_files


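Note: a quick sketch of how `filter_missing` behaves, using hypothetical file paths:

files = ['data/skeletons/S001C001P001R001A001.skeleton',
         'data/skeletons/S001C001P001R001A002.skeleton']  # made-up paths
missing = ['S001C001P001R001A002']
filter_missing(files, missing)  # -> ['data/skeletons/S001C001P001R001A001.skeleton']
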
23 changes: 21 additions & 2 deletions src/skelcast/experiments/runner.py
@@ -9,6 +9,7 @@
from skelcast.data.dataset import NTURGBDCollateFn, NTURGBDSample
from skelcast.callbacks.console import ConsoleCallback
from skelcast.callbacks.checkpoint import CheckpointCallback
from skelcast.logger.base import BaseLogger

class Runner:
def __init__(self,
@@ -22,13 +23,14 @@ def __init__(self,
n_epochs: int = 10,
device: str = 'cpu',
checkpoint_dir: str = None,
checkpoint_frequency: int = 1) -> None:
checkpoint_frequency: int = 1,
logger: BaseLogger = None) -> None:
self.train_set = train_set
self.val_set = val_set
self.train_batch_size = train_batch_size
self.val_batch_size = val_batch_size
self.block_size = block_size
self._collate_fn = NTURGBDCollateFn(block_size=self.block_size)
self._collate_fn = NTURGBDCollateFn(block_size=self.block_size, is_packed=True)
self.train_loader = DataLoader(dataset=self.train_set, batch_size=self.train_batch_size, shuffle=True, collate_fn=self._collate_fn)
self.val_loader = DataLoader(dataset=self.val_set, batch_size=self.val_batch_size, shuffle=False, collate_fn=self._collate_fn)
self.model = model
@@ -59,6 +61,7 @@ def __init__(self,
assert os.path.exists(self.checkpoint_dir), f'The designated checkpoint directory `{self.checkpoint_dir}` does not exist.'
self.checkpoint_callback = CheckpointCallback(checkpoint_dir=self.checkpoint_dir,
frequency=self.checkpoint_frequency)
self.logger = logger

def setup(self):
self.model.to(self.device)
@@ -80,6 +83,7 @@ def fit(self):
epoch_loss = sum(self.training_loss_per_step[epoch * self._total_train_batches:(epoch + 1) * self._total_train_batches]) / self._total_train_batches
self.console_callback.on_epoch_end(epoch=epoch,
epoch_loss=epoch_loss, phase='train')
if self.logger is not None:
self.logger.add_scalar(tag='train/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
self.training_loss_history.append(epoch_loss)
for val_batch_idx, val_batch in enumerate(self.val_loader):
self.validation_step(val_batch=val_batch)
@@ -90,6 +94,7 @@ def fit(self):
self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='val')
self.validation_loss_history.append(epoch_loss)
self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self)
if self.logger is not None:
self.logger.add_scalar(tag='val/epoch_loss', scalar_value=epoch_loss, global_step=epoch)

return {
'training_loss_history': self.training_loss_history,
@@ -111,6 +116,9 @@ def training_step(self, train_batch: NTURGBDSample):
self.optimizer.step()
# Print the loss
self.training_loss_per_step.append(loss.item())
# Log it to the logger
if self.logger is not None:
self.logger.add_scalar(tag='train/step_loss', scalar_value=loss.item(), global_step=len(self.training_loss_per_step))

def validation_step(self, val_batch: NTURGBDSample):
x, y = val_batch.x, val_batch.y
@@ -121,6 +129,9 @@ def validation_step(self, val_batch: NTURGBDSample):
out = self.model.validation_step(x, y)
loss = out['loss']
self.validation_loss_per_step.append(loss.item())
# Log it to the logger
if self.logger is not None:
self.logger.add_scalar(tag='val/step_loss', scalar_value=loss.item(), global_step=len(self.validation_loss_per_step))

def resume(self, checkpoint_path):
"""
@@ -152,18 +163,26 @@ def resume(self, checkpoint_path):
self.console_callback.on_batch_end(batch_idx=train_batch_idx,
loss=self.training_loss_per_step[-1],
phase='train')
if self.logger is not None:
self.logger.add_scalar(tag='train/step_loss', scalar_value=self.training_loss_per_step[-1], global_step=len(self.training_loss_per_step))
epoch_loss = sum(self.training_loss_per_step[epoch * self._total_train_batches:(epoch + 1) * self._total_train_batches]) / self._total_train_batches
self.console_callback.on_epoch_end(epoch=epoch,
epoch_loss=epoch_loss, phase='train')
self.training_loss_history.append(epoch_loss)
if self.logger is not None:
self.logger.add_scalar(tag='train/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
for val_batch_idx, val_batch in enumerate(self.val_loader):
self.validation_step(val_batch=val_batch)
self.console_callback.on_batch_end(batch_idx=val_batch_idx,
loss=self.validation_loss_per_step[-1],
phase='val')
if self.logger is not None:
self.logger.add_scalar(tag='val/step_loss', scalar_value=self.validation_loss_per_step[-1], global_step=len(self.validation_loss_per_step))
epoch_loss = sum(self.validation_loss_per_step[epoch * self._total_val_batches:(epoch + 1) * self._total_val_batches]) / self._total_val_batches
self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='val')
self.validation_loss_history.append(epoch_loss)
if self.logger is not None:
self.logger.add_scalar(tag='val/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self)

return {
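Note: with the optional `logger` argument threaded through, TensorBoard tracking becomes a construction-time concern. A sketch, assuming `train_set`, `val_set`, `model`, and `optimizer` are already built; the Runner parameters before `n_epochs` are truncated out of this diff, so their names here are assumptions:

from skelcast.experiments.runner import Runner
from skelcast.logger.tensorboard_logger import TensorboardLogger

logger = TensorboardLogger(log_dir='runs/lstm-baseline')  # hypothetical run name
runner = Runner(train_set=train_set, val_set=val_set,     # pre-built objects (assumed)
                model=model, optimizer=optimizer,
                n_epochs=10, device='cpu',
                checkpoint_dir='checkpoints',             # must already exist (asserted in __init__)
                logger=logger)
runner.setup()
history = runner.fit()
logger.close()
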
Empty file added src/skelcast/logger/__init__.py
23 changes: 23 additions & 0 deletions src/skelcast/logger/base.py
@@ -0,0 +1,23 @@

from abc import ABC, abstractmethod

class BaseLogger(ABC):
@abstractmethod
def add_scalar(self):
pass

@abstractmethod
def add_scalars(self):
pass

@abstractmethod
def add_histogram(self):
pass

@abstractmethod
def add_image(self):
pass

@abstractmethod
def close(self):
pass
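
Note: any tracking backend can now be swapped in by subclassing `BaseLogger`. As a sketch, a hypothetical no-op implementation that satisfies the interface while discarding everything:

from skelcast.logger.base import BaseLogger

class NoOpLogger(BaseLogger):
    """Drop-in stand-in for runs that should not be tracked."""
    def add_scalar(self, *args, **kwargs): pass
    def add_scalars(self, *args, **kwargs): pass
    def add_histogram(self, *args, **kwargs): pass
    def add_image(self, *args, **kwargs): pass
    def close(self): pass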
25 changes: 25 additions & 0 deletions src/skelcast/logger/tensorboard_logger.py
@@ -0,0 +1,25 @@
from torch.utils.tensorboard import SummaryWriter
from skelcast.logger.base import BaseLogger


class TensorboardLogger(BaseLogger):
def __init__(self, log_dir):
super().__init__()
self.writer = SummaryWriter(log_dir)

def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
self.writer.add_scalar(tag, scalar_value, global_step, walltime)

def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
self.writer.add_scalars(main_tag, tag_scalar_dict, global_step, walltime)

def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None):
self.writer.add_histogram(tag, values, global_step, bins, walltime, max_bins)

def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
self.writer.add_image(tag, img_tensor, global_step, walltime, dataformats)

# Implement other methods from SummaryWriter as needed

def close(self):
self.writer.close()
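
Note: a quick usage sketch, mirroring the `tag`/`scalar_value`/`global_step` keywords the Runner passes (loss values are illustrative):

from skelcast.logger.tensorboard_logger import TensorboardLogger

logger = TensorboardLogger(log_dir='runs/demo')  # hypothetical directory
for step, loss in enumerate([0.91, 0.74, 0.58]):
    logger.add_scalar(tag='train/step_loss', scalar_value=loss, global_step=step)
logger.close()
# Inspect with: tensorboard --logdir runs/demo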
30 changes: 30 additions & 0 deletions src/skelcast/models/module.py
@@ -5,6 +5,8 @@
class SkelcastModule(nn.Module, metaclass=abc.ABCMeta):
def __init__(self) -> None:
super(SkelcastModule, self).__init__()
self.gradients = dict()
self.gradient_update_ratios = dict()

@abc.abstractmethod
def predict(self, *args, **kwargs):
@@ -26,3 +28,31 @@ def validation_step(self, *args, **kwargs):
Implements a validation step of a module
"""
pass

def gradient_flow(self):
"""
Snapshots the current gradients of all trainable parameters into the `gradients` dictionary.
"""
for name, param in self.named_parameters():
if param.requires_grad and param.grad is not None:
self.gradients[name] = param.grad.clone().detach().cpu().numpy()

def compute_gradient_update_norm(self, lr: float):
"""
Computes the ratio of the parameter update norm to the parameter norm and stores it in the `gradient_update_ratios`
dictionary. The update is approximated as a vanilla gradient descent step, i.e. `lr * grad`.

Args:
- lr (float): The optimizer's learning rate
"""
for name, param in self.named_parameters():
if param.requires_grad and param.grad is not None:
self.gradient_update_ratios[name] = (lr * param.grad.norm() / param.norm()).detach().cpu().numpy()

def get_gradient_histograms(self):
"""
Returns the flattened gradients of the module's parameters as 1-D numpy arrays, keyed by parameter name.
"""
return {name: param.grad.clone().view(-1).detach().cpu().numpy() for name, param in self.named_parameters() if
param.requires_grad and param.grad is not None}
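
Note: the gradient diagnostics are meant to run between `backward()` and `optimizer.step()`, while gradients are populated. A sketch of feeding them to a logger; `model`, `optimizer`, `x`, `y`, and the dict-returning `training_step` are assumptions modeled on the Runner code above:

out = model.training_step(x, y)                # hypothetical SkelcastModule subclass
loss = out['loss']
loss.backward()
model.gradient_flow()                          # snapshot raw gradients
model.compute_gradient_update_norm(lr=1e-3)    # ratio of lr * ||grad|| to ||param||
for name, values in model.get_gradient_histograms().items():
    logger.add_histogram(tag=f'grads/{name}', values=values)
optimizer.step()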

36 changes: 21 additions & 15 deletions src/skelcast/models/rnn/lstm.py
@@ -1,5 +1,8 @@
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence

from typing import Union

from skelcast.models import SkelcastModule

@@ -11,7 +14,8 @@ def __init__(self,
batch_first: bool = True,
num_bodies: int = 1,
n_joints: int = 25,
n_dims: int = 3) -> None:
n_dims: int = 3,
reduction: str = 'mean') -> None:
super().__init__()
self.num_bodies = num_bodies
self.n_joints = n_joints
@@ -22,24 +26,26 @@ def __init__(self,
batch_first=batch_first)
self.linear = nn.Linear(in_features=hidden_size, out_features=input_size)
self.relu = nn.ReLU()
self.criterion = nn.MSELoss(reduction='sum')
self.criterion = nn.MSELoss(reduction=reduction)

def forward(self, x: torch.Tensor,
y: torch.Tensor = None):
assert x.ndim == 5, f'`x` must be a 5-dimensional tensor. Found {x.ndim} dimension(s).'
batch_size, context_size, num_bodies, n_joints, n_dims = x.shape
assert num_bodies == self.num_bodies, f'The number of bodies in the position 2 of the tensor is {num_bodies}, but it should be {self.num_bodies}'
assert n_joints == self.n_joints, f'The number of bodies in the position 3 of the tensor is {n_joints}, but it should be {self.n_joints}'
assert n_dims == self.n_dims, f'The number of bodies in the position 3 of the tensor is {n_dims}, but it should be {self.n_dims}'
def forward(self, x: Union[torch.Tensor, PackedSequence],
y: Union[torch.Tensor, PackedSequence] = None):
if isinstance(x, torch.Tensor):
assert x.ndim == 3, f'`x` must be a 3-dimensional tensor. Found {x.ndim} dimension(s).'
batch_size, context_size, n_dim = x.shape
assert n_dim == self.num_bodies * self.n_joints * self.n_dims, f'The feature dimension of `x` is {n_dim}, but it should be {self.num_bodies * self.n_joints * self.n_dims}.'
else:
assert x.data.ndim == 3, f'`x.data` must be a 3-dimensional tensor. Found {x.data.ndim} dimension(s).'
batch_size, context_size, n_dim = x.data.shape
assert n_dim == self.num_bodies * self.n_joints * self.n_dims, f'The feature dimension of `x.data` is {n_dim}, but it should be {self.num_bodies * self.n_joints * self.n_dims}.'

x = x.view(batch_size, context_size, num_bodies * n_joints * n_dims)
x = self.linear_transform(x)
x = self.linear_transform(x if isinstance(x, torch.Tensor) else x.data)
out, _ = self.lstm(x)
out = self.linear(out)
out = self.relu(out)
out = self.linear(out if isinstance(out, torch.Tensor) else out.data)
out = self.relu(out if isinstance(out, torch.Tensor) else out.data)
if y is not None:
y = y.view(batch_size, context_size, num_bodies * n_joints * n_dims)
loss = self.criterion(out, y)
loss = self.criterion(out if isinstance(out, torch.Tensor) else out.data,
y if isinstance(y, torch.Tensor) else y.data)
return out, loss
return out
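
Note: the forward pass now accepts a `PackedSequence` whose `.data` carries whole windows, which keeps padded windows out of the loss computation. A shape-level sketch with 1 body, 25 joints, and 3 coordinates (a 75-dimensional feature vector):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Two padded samples with 4 and 2 valid windows, each window being 8 frames of 75 features.
batch = torch.randn(2, 4, 8, 1 * 25 * 3)
lengths = torch.tensor([4, 2])
packed = pack_padded_sequence(batch, lengths, batch_first=True)
print(packed.data.shape)  # torch.Size([6, 8, 75]) -- the 3-D `x.data` the forward pass checks
# out, loss = model(packed, y)  # `model`: the LSTM regressor above (its class name is truncated from the diff)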
