Merge pull request #27 from kaseris/dev
Merge to main
kaseris committed Nov 29, 2023
2 parents 2afec3c + 98b6e60 commit b88139a
Showing 9 changed files with 124 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
data/skeletons/*
.vscode/
runs/*

# Byte-compiled / optimized / DLL files
__pycache__/
5 changes: 3 additions & 2 deletions src/skelcast/data/dataset.py
@@ -132,8 +132,9 @@ def __call__(self, batch) -> NTURGBDSample:
batch_y = []
for sample, _ in batch:
x, y = self.get_windows(sample)
batch_x.append(x)
batch_y.append(y)
chunk_size, context_len, n_bodies, n_joints, n_dims = x.shape
batch_x.append(x.view(chunk_size, context_len, n_bodies * n_joints * n_dims))
batch_y.append(y.view(chunk_size, context_len, n_bodies * n_joints * n_dims))
# Pad the sequences to the maximum sequence length in the batch
batch_x = torch.nn.utils.rnn.pad_sequence(batch_x, batch_first=True)
batch_y = torch.nn.utils.rnn.pad_sequence(batch_y, batch_first=True)
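Since `get_windows` can return a different number of windows per sample, `pad_sequence` zero-pads the flattened chunks to the longest sample in the batch. A minimal sketch of that behavior (shapes are illustrative, not taken from the dataset):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two samples with different numbers of windows, each flattened to
# (chunk_size, context_len, n_bodies * n_joints * n_dims).
context_len, feat = 8, 1 * 25 * 3
a = torch.randn(5, context_len, feat)  # 5 windows
b = torch.randn(3, context_len, feat)  # 3 windows

# batch_first=True stacks to (batch, max_chunks, context_len, feat),
# zero-padding the shorter sample along the window dimension.
batch = pad_sequence([a, b], batch_first=True)
print(batch.shape)  # torch.Size([2, 5, 8, 75])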
2 changes: 0 additions & 2 deletions src/skelcast/data/prepare_data.py
@@ -5,7 +5,6 @@
def get_missing_files(missing_files_dir):
missing_skel_files = []
for fname in os.listdir(missing_files_dir):
print(fname)
with open(osp.join(missing_files_dir, fname)) as f:
for idx, line in enumerate(f):
if idx > 2:
@@ -31,7 +30,6 @@ def filter_missing(skeleton_files: list, missing_skeleton_names: list):
for f in skeleton_files
if os.path.splitext(os.path.basename(f))[0] not in missing_skeleton_names
]
print(f"Skeleton files after filtering: {len(filtered_skeleton_files)} files left.")
return filtered_skeleton_files


23 changes: 21 additions & 2 deletions src/skelcast/experiments/runner.py
@@ -9,6 +9,7 @@
from skelcast.data.dataset import NTURGBDCollateFn, NTURGBDSample
from skelcast.callbacks.console import ConsoleCallback
from skelcast.callbacks.checkpoint import CheckpointCallback
from skelcast.logger.base import BaseLogger

class Runner:
def __init__(self,
@@ -22,13 +23,14 @@ def __init__(self,
n_epochs: int = 10,
device: str = 'cpu',
checkpoint_dir: str = None,
checkpoint_frequency: int = 1) -> None:
checkpoint_frequency: int = 1,
logger: BaseLogger = None) -> None:
self.train_set = train_set
self.val_set = val_set
self.train_batch_size = train_batch_size
self.val_batch_size = val_batch_size
self.block_size = block_size
self._collate_fn = NTURGBDCollateFn(block_size=self.block_size)
self._collate_fn = NTURGBDCollateFn(block_size=self.block_size, is_packed=True)
self.train_loader = DataLoader(dataset=self.train_set, batch_size=self.train_batch_size, shuffle=True, collate_fn=self._collate_fn)
self.val_loader = DataLoader(dataset=self.val_set, batch_size=self.val_batch_size, shuffle=False, collate_fn=self._collate_fn)
self.model = model
@@ -59,6 +61,7 @@ def __init__(self,
assert os.path.exists(self.checkpoint_dir), f'The designated checkpoint directory `{self.checkpoint_dir}` does not exist.'
self.checkpoint_callback = CheckpointCallback(checkpoint_dir=self.checkpoint_dir,
frequency=self.checkpoint_frequency)
self.logger = logger

def setup(self):
self.model.to(self.device)
@@ -80,6 +83,7 @@ def fit(self):
epoch_loss = sum(self.training_loss_per_step[epoch * self._total_train_batches:(epoch + 1) * self._total_train_batches]) / self._total_train_batches
self.console_callback.on_epoch_end(epoch=epoch,
epoch_loss=epoch_loss, phase='train')
if self.logger is not None:
    self.logger.add_scalar(tag='train/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
self.training_loss_history.append(epoch_loss)
for val_batch_idx, val_batch in enumerate(self.val_loader):
self.validation_step(val_batch=val_batch)
@@ -90,6 +94,7 @@ def fit(self):
self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='val')
self.validation_loss_history.append(epoch_loss)
self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self)
if self.logger is not None:
    self.logger.add_scalar(tag='val/epoch_loss', scalar_value=epoch_loss, global_step=epoch)

return {
'training_loss_history': self.training_loss_history,
@@ -111,6 +116,9 @@ def training_step(self, train_batch: NTURGBDSample):
self.optimizer.step()
# Record the step loss
self.training_loss_per_step.append(loss.item())
# Log it to the logger
if self.logger is not None:
self.logger.add_scalar(tag='train/step_loss', scalar_value=loss.item(), global_step=len(self.training_loss_per_step))

def validation_step(self, val_batch: NTURGBDSample):
x, y = val_batch.x, val_batch.y
@@ -121,6 +129,9 @@ def validation_step(self, val_batch: NTURGBDSample):
out = self.model.validation_step(x, y)
loss = out['loss']
self.validation_loss_per_step.append(loss.item())
# Log it to the logger
if self.logger is not None:
self.logger.add_scalar(tag='val/step_loss', scalar_value=loss.item(), global_step=len(self.validation_loss_per_step))

def resume(self, checkpoint_path):
"""
@@ -152,18 +163,26 @@ def resume(self, checkpoint_path):
self.console_callback.on_batch_end(batch_idx=train_batch_idx,
loss=self.training_loss_per_step[-1],
phase='train')
if self.logger is not None:
self.logger.add_scalar(tag='train/step_loss', scalar_value=self.training_loss_per_step[-1], global_step=len(self.training_loss_per_step))
epoch_loss = sum(self.training_loss_per_step[epoch * self._total_train_batches:(epoch + 1) * self._total_train_batches]) / self._total_train_batches
self.console_callback.on_epoch_end(epoch=epoch,
epoch_loss=epoch_loss, phase='train')
self.training_loss_history.append(epoch_loss)
if self.logger is not None:
self.logger.add_scalar(tag='train/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
for val_batch_idx, val_batch in enumerate(self.val_loader):
self.validation_step(val_batch=val_batch)
self.console_callback.on_batch_end(batch_idx=val_batch_idx,
loss=self.validation_loss_per_step[-1],
phase='val')
if self.logger is not None:
self.logger.add_scalar(tag='val/step_loss', scalar_value=self.validation_loss_per_step[-1], global_step=len(self.validation_loss_per_step))
epoch_loss = sum(self.validation_loss_per_step[epoch * self._total_val_batches:(epoch + 1) * self._total_val_batches]) / self._total_val_batches
self.console_callback.on_epoch_end(epoch=epoch, epoch_loss=epoch_loss, phase='val')
self.validation_loss_history.append(epoch_loss)
if self.logger is not None:
self.logger.add_scalar(tag='val/epoch_loss', scalar_value=epoch_loss, global_step=epoch)
self.checkpoint_callback.on_epoch_end(epoch=epoch, runner=self)

return {
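Taken together, the runner now threads an optional `logger` through both `fit` and `resume`. A hedged sketch of how it might be wired up — the dataset, model, and optimizer objects are placeholders, and any constructor arguments not visible in this diff are assumptions:

from skelcast.experiments.runner import Runner
from skelcast.logger.tensorboard_logger import TensorboardLogger

# `train_set`, `val_set` and `model` are hypothetical objects; the keyword
# names mirror the attributes visible in the diff above.
runner = Runner(train_set=train_set, val_set=val_set,
                train_batch_size=32, val_batch_size=32,
                block_size=8, model=model,
                n_epochs=10, device='cuda',
                checkpoint_dir='checkpoints/', checkpoint_frequency=1,
                logger=TensorboardLogger(log_dir='runs/experiment_1'))
runner.setup()
history = runner.fit()  # returns the training/validation loss histories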
Empty file added src/skelcast/logger/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions src/skelcast/logger/base.py
@@ -0,0 +1,23 @@

from abc import ABC, abstractmethod

class BaseLogger(ABC):
@abstractmethod
def add_scalar(self):
pass

@abstractmethod
def add_scalars(self):
pass

@abstractmethod
def add_histogram(self):
pass

@abstractmethod
def add_image(self):
pass

@abstractmethod
def close(self):
pass
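Every concrete logger must implement these five methods. As an illustration (not part of this PR), a no-op subclass that satisfies the interface when logging should be disabled could look like:

from skelcast.logger.base import BaseLogger

class NoOpLogger(BaseLogger):
    """Hypothetical logger that accepts calls and discards them."""
    def add_scalar(self, *args, **kwargs): pass
    def add_scalars(self, *args, **kwargs): pass
    def add_histogram(self, *args, **kwargs): pass
    def add_image(self, *args, **kwargs): pass
    def close(self): pass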
25 changes: 25 additions & 0 deletions src/skelcast/logger/tensorboard_logger.py
@@ -0,0 +1,25 @@
from torch.utils.tensorboard import SummaryWriter
from skelcast.logger.base import BaseLogger


class TensorboardLogger(BaseLogger):
def __init__(self, log_dir):
super().__init__()
self.writer = SummaryWriter(log_dir)

def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
self.writer.add_scalar(tag, scalar_value, global_step, walltime)

def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
self.writer.add_scalars(main_tag, tag_scalar_dict, global_step, walltime)

def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None):
self.writer.add_histogram(tag, values, global_step, bins, walltime, max_bins)

def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
self.writer.add_image(tag, img_tensor, global_step, walltime, dataformats)

# Implement other methods from SummaryWriter as needed

def close(self):
self.writer.close()
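A short usage sketch (the `runs/example` path is illustrative; note that `runs/*` was added to `.gitignore` in this same PR):

from skelcast.logger.tensorboard_logger import TensorboardLogger

logger = TensorboardLogger(log_dir='runs/example')
for step, loss in enumerate([0.9, 0.7, 0.5]):
    logger.add_scalar(tag='train/step_loss', scalar_value=loss, global_step=step)
logger.close()
# Inspect with: tensorboard --logdir runs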
30 changes: 30 additions & 0 deletions src/skelcast/models/module.py
@@ -5,6 +5,8 @@
class SkelcastModule(nn.Module, metaclass=abc.ABCMeta):
def __init__(self) -> None:
super(SkelcastModule, self).__init__()
self.gradients = dict()
self.gradient_update_ratios = dict()

@abc.abstractmethod
def predict(self, *args, **kwargs):
@@ -26,3 +28,31 @@ def validation_step(self, *args, **kwargs):
Implements a validation step of a module
"""
pass

def gradient_flow(self):
"""
Snapshots a detached CPU copy of each trainable parameter's gradient into the `gradients` dictionary.
"""
for name, param in self.named_parameters():
if param.requires_grad and param.grad is not None:
self.gradients[name] = param.grad.clone().detach().cpu().numpy()

def compute_gradient_update_norm(self, lr: float):
"""
Computes the ratio of the parameter update to the parameter norm and stores it in the gradient_update_ratios
dictionary. The gradient update is approximated as the vanilla gradient descent update.
Args:
- lr (float): The optimizer's learning rate
"""
for name, param in self.named_parameters():
if param.requires_grad and param.grad is not None:
self.gradient_update_ratios[name] = (lr * param.grad.norm() / param.norm()).detach().cpu().numpy()

def get_gradient_histograms(self):
"""
Returns the flattened gradient of each trainable parameter, keyed by parameter name.
"""
return {name: param.grad.clone().view(-1).detach().cpu().numpy() for name, param in self.named_parameters() if
param.requires_grad and param.grad is not None}
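These diagnostics only read `param.grad`, so they should be called after `backward()` and before the gradients are zeroed. A hedged sketch of where a training loop might invoke them (`model`, `loss`, and `optimizer` are placeholders):

loss.backward()
model.gradient_flow()                        # snapshot per-parameter gradients
model.compute_gradient_update_norm(lr=1e-3)  # lr should match the optimizer's
hists = model.get_gradient_histograms()      # flat arrays, e.g. for add_histogram
optimizer.step()
optimizer.zero_grad()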

36 changes: 21 additions & 15 deletions src/skelcast/models/rnn/lstm.py
@@ -1,5 +1,8 @@
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence

from typing import Union

from skelcast.models import SkelcastModule

@@ -11,7 +14,8 @@ def __init__(self,
batch_first: bool = True,
num_bodies: int = 1,
n_joints: int = 25,
n_dims: int = 3) -> None:
n_dims: int = 3,
reduction: str = 'mean') -> None:
super().__init__()
self.num_bodies = num_bodies
self.n_joints = n_joints
@@ -22,24 +26,26 @@ def __init__(self,
batch_first=batch_first)
self.linear = nn.Linear(in_features=hidden_size, out_features=input_size)
self.relu = nn.ReLU()
self.criterion = nn.MSELoss(reduction='sum')
self.criterion = nn.MSELoss(reduction=reduction)

def forward(self, x: torch.Tensor,
y: torch.Tensor = None):
assert x.ndim == 5, f'`x` must be a 5-dimensional tensor. Found {x.ndim} dimension(s).'
batch_size, context_size, num_bodies, n_joints, n_dims = x.shape
assert num_bodies == self.num_bodies, f'The number of bodies in the position 2 of the tensor is {num_bodies}, but it should be {self.num_bodies}'
assert n_joints == self.n_joints, f'The number of bodies in the position 3 of the tensor is {n_joints}, but it should be {self.n_joints}'
assert n_dims == self.n_dims, f'The number of bodies in the position 3 of the tensor is {n_dims}, but it should be {self.n_dims}'
def forward(self, x: Union[torch.Tensor, PackedSequence],
y: Union[torch.Tensor, PackedSequence] = None):
if isinstance(x, torch.Tensor):
assert x.ndim == 3, f'`x` must be a 3-dimensional tensor. Found {x.ndim} dimension(s).'
batch_size, context_size, n_dim = x.shape
assert n_dim == self.num_bodies * self.n_joints * self.n_dims, f'The feature dimension of the tensor is {n_dim}, but it should be {self.num_bodies * self.n_joints * self.n_dims}'
else:
assert x.data.ndim == 3, f'`x.data` must be a 3-dimensional tensor. Found {x.data.ndim} dimension(s).'
batch_size, context_size, n_dim = x.data.shape
assert n_dim == self.num_bodies * self.n_joints * self.n_dims, f'The feature dimension of the tensor is {n_dim}, but it should be {self.num_bodies * self.n_joints * self.n_dims}'

x = x.view(batch_size, context_size, num_bodies * n_joints * n_dims)
x = self.linear_transform(x)
x = self.linear_transform(x if isinstance(x, torch.Tensor) else x.data)
out, _ = self.lstm(x)
out = self.linear(out)
out = self.relu(out)
out = self.linear(out if isinstance(out, torch.Tensor) else out.data)
out = self.relu(out if isinstance(out, torch.Tensor) else out.data)
if y is not None:
y = y.view(batch_size, context_size, num_bodies * n_joints * n_dims)
loss = self.criterion(out, y)
loss = self.criterion(out if isinstance(out, torch.Tensor) else out.data,
y if isinstance(y, torch.Tensor) else y.data)
return out, loss
return out
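With `is_packed=True` in the collate function, the model can now receive either a plain tensor or a `PackedSequence`; in the packed case the forward unwraps `.data` before the linear and LSTM layers. A hedged sketch of the tensor path (`model` stands for the LSTM regressor defined in this file, whose class name lies outside the shown hunk; `input_size` is assumed to equal num_bodies * n_joints * n_dims):

import torch

feat = 1 * 25 * 3              # num_bodies * n_joints * n_dims = 75
x = torch.randn(4, 8, feat)    # (batch, context, features)
y = torch.randn(4, 8, feat)

out, loss = model(x, y)        # training: prediction plus MSE loss
out = model(x)                 # inference: prediction only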

