Merge pull request #46 from kaseris/dev
Dev
kaseris committed Dec 8, 2023
2 parents 46fab5c + 7a61fdc commit 4db4252
Showing 4 changed files with 321 additions and 10 deletions.
48 changes: 46 additions & 2 deletions src/skelcast/data/dataset.py
@@ -2,7 +2,7 @@
import logging
import pickle
from dataclasses import dataclass
from typing import Any, Tuple, List, Optional

import numpy as np
import torch
@@ -92,6 +92,7 @@ class NTURGBDSample:
    x: torch.Tensor
    y: torch.Tensor
    label: Tuple[int, str]
    mask: Optional[torch.Tensor] = None

def nturbgd_collate_fn_with_overlapping_context_window(batch: List[NTURGBDSample]) -> NTURGBDSample:
# TODO: Normalize each sample individually along its 3 axes
@@ -102,7 +103,7 @@ def nturbgd_collate_fn_with_overlapping_context_window(batch: List[NTURGBDSample]) -> NTURGBDSample:
# batch_x = default_collate(batch_x)
# batch_y = default_collate(batch_y)
batch_label = default_collate(batch_label)
    return NTURGBDSample(x=batch_x, y=batch_y, label=batch_label, mask=None)


@COLLATE_FUNCS.register_module()
@@ -161,7 +162,50 @@ def get_windows(self, x):
input_windows_tensor = torch.tensor(input_windows, dtype=torch.float)
target_labels_tensor = torch.tensor(np.array(target_labels), dtype=torch.float)
return input_windows_tensor, target_labels_tensor


@COLLATE_FUNCS.register_module()
class NTURGBDCollateFnWithRandomSampledContextWindow:
"""
Custom collate function for batched variable-length sequences.
During the __call__ function, we creata `block_size`-long context windows, for each sequence in the batch.
If is_packed is True, we pack the padded sequences, otherwise we return the padded sequences as is.
Args:
- block_size (int): Sequence's context length.
- is_packed (bool): Whether to pack the padded sequence or not.
Returns:
The batched padded sequences ready to be fed to a transformer or an lstm model.
"""
def __init__(self, block_size: int) -> None:
self.block_size = block_size

    def __call__(self, batch) -> NTURGBDSample:
        # Pick a random index for each element of the batch and create a `block_size`-long
        # context window around it. If the element's sequence length is at most
        # `block_size`, the entire sequence is sampled instead.
seq_lens = [sample.shape[0] for sample, _ in batch]
labels = [label for _, label in batch]
pre_batch = []
for sample, _ in batch:
logging.debug(f'sample.shape: {sample.shape}')
if sample.shape[0] <= self.block_size:
# Sample the entire sequence
                logging.debug(f'Detected a sample with a sequence length of {sample.shape[0]}')
pre_batch.append(sample)
else:
                # Sample a random start index; `high` is exclusive, so the +1 lets the
                # window end exactly at the sequence's final frame
                idx = torch.randint(low=0, high=sample.shape[0] - self.block_size + 1, size=(1,)).item()
                pre_batch.append(sample[idx:idx + self.block_size, ...])
        # Pad the sequences to the longest window in the batch
        batch_x = torch.nn.utils.rnn.pad_sequence(pre_batch, batch_first=True)
        # Generate padding masks: ones over the valid timesteps, zeros over the padding.
        # Window lengths are capped at `block_size`, the longest window we sample.
        masks = torch.nn.utils.rnn.pad_sequence(
            [torch.ones(min(seq_len, self.block_size)) for seq_len in seq_lens],
            batch_first=True).to(torch.float32)
        return NTURGBDSample(x=batch_x, y=batch_x, label=labels, mask=masks)


@DATASETS.register_module()
class NTURGBDDataset(Dataset):
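For reference, a minimal usage sketch of the new collate function (the two-sequence batch and its labels are made up for illustration; the mask shape assumes the padding-mask fix above):

```python
import torch
from skelcast.data.dataset import NTURGBDCollateFnWithRandomSampledContextWindow

collate_fn = NTURGBDCollateFnWithRandomSampledContextWindow(block_size=8)
batch = [
    (torch.randn(12, 75), (0, 'drink water')),  # longer than block_size: a random window is cut
    (torch.randn(5, 75), (1, 'sit down')),      # shorter: kept whole, then padded
]
sample = collate_fn(batch)
print(sample.x.shape)     # torch.Size([2, 8, 75])
print(sample.mask.shape)  # torch.Size([2, 8])
```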
6 changes: 5 additions & 1 deletion src/skelcast/models/__init__.py
@@ -2,6 +2,10 @@
from skelcast.core.registry import Registry

MODELS = Registry()
ENCODERS = Registry()
DECODERS = Registry()

from .rnn.lstm import SimpleLSTMRegressor
from .transformers.transformer import ForecastTransformer
from .rnn.pvred import PositionalVelocityRecurrentEncoderDecoder
from .rnn.pvred import Encoder, Decoder
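A hypothetical sketch of how the new registries are meant to be consumed; the `MyEncoder` name is illustrative, and the decorator usage mirrors how `Encoder` and `Decoder` register themselves in pvred.py below:

```python
from skelcast.models import ENCODERS

@ENCODERS.register_module()
class MyEncoder:
    """Registered with ENCODERS via the same decorator pattern already used by MODELS."""
```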
246 changes: 246 additions & 0 deletions src/skelcast/models/rnn/pvred.py
@@ -0,0 +1,246 @@
import torch
import torch.nn as nn

from skelcast.models import MODELS, ENCODERS, DECODERS
from skelcast.models.module import SkelcastModule
from skelcast.models.transformers.base import PositionalEncoding


@ENCODERS.register_module()
class Encoder(nn.Module):
    def __init__(self, rnn_type: str = 'lstm',
input_dim: int = 75,
hidden_dim: int = 256,
batch_first: bool = True,
dropout: float = 0.2,
use_residual: bool = True) -> None:
super().__init__()
        assert rnn_type in ['lstm', 'gru'], f'rnn_type must be one of lstm, gru, got {rnn_type}'
self.rnn_type = rnn_type
self.batch_first = batch_first
self.use_residual = use_residual

if self.rnn_type == 'lstm':
self.rnn = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=batch_first)
elif self.rnn_type == 'gru':
self.rnn = nn.GRU(input_size=input_dim, hidden_size=hidden_dim, batch_first=batch_first)
self.linear = nn.Linear(hidden_dim, input_dim)
self.dropout = nn.Dropout(dropout)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out, hidden = self.rnn(x)
out = self.dropout(out)
out = self.linear(out)
if self.use_residual:
out = out + x
return out, hidden


@DECODERS.register_module()
class Decoder(nn.Module):
    def __init__(self, rnn_type: str = 'lstm',
input_dim: int = 75,
hidden_dim: int = 256,
batch_first: bool = True,
dropout: float = 0.2,
use_residual: bool = True) -> None:
super().__init__()
        assert rnn_type in ['lstm', 'gru'], f'rnn_type must be one of lstm, gru, got {rnn_type}'
self.rnn_type = rnn_type
self.batch_first = batch_first
self.use_residual = use_residual

if self.rnn_type == 'lstm':
self.rnn = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=batch_first)
elif self.rnn_type == 'gru':
self.rnn = nn.GRU(input_size=input_dim, hidden_size=hidden_dim, batch_first=batch_first)
self.linear = nn.Linear(hidden_dim, input_dim)
self.dropout = nn.Dropout(dropout)

def forward(self, x: torch.Tensor, hidden: torch.Tensor = None, timesteps_to_predict: int = 5) -> torch.Tensor:
predictions = []
for _ in range(timesteps_to_predict):
out, hidden = self.rnn(x, hidden)
out = self.dropout(out)
out = self.linear(out)
if self.use_residual:
out = out + x
predictions.append(out)
x = out
return torch.cat(predictions, dim=1)


@MODELS.register_module()
class PositionalVelocityRecurrentEncoderDecoder(SkelcastModule):
"""
Positional Velocity Recurrent Encoder Decoder (PVRED) model.
The model was proposed in the paper: https://arxiv.org/abs/1906.06514
    The model consists of an encoder and a decoder. The encoder is a recurrent neural network (RNN)
    that encodes the observed part of the input sequence into a latent representation. The decoder,
    also an RNN, autoregressively decodes that latent representation into the remaining (future)
    frames of the sequence.
    The authors of PVRED introduce the idea of positional encoding. Below, we describe how the
    positional encoding can be applied to the input sequence.
Three ways of taking into account the positional encoding:
1. Concatenate the positional encoding to the input
2. Add the positional encoding to the input
3. NOP (no positional encoding) - if pos_enc is None
Args:
---
- `input_dim` (`int`): Input dimension of the model
- `enc_hidden_dim` (`int`): Hidden dimension of the encoder
- `dec_hidden_dim` (`int`): Hidden dimension of the decoder
- `enc_type` (`str`): Type of the encoder, one of lstm, gru
- `dec_type` (`str`): Type of the decoder, one of lstm, gru
- `include_velocity` (`bool`): Flag to indicate whether to include the velocity of the input
- `pos_enc` (`str`): Type of the positional encoding, one of concat, add, None
- `loss_fn` (`torch.nn.Module`): Loss function to be used for training
- `batch_first` (`bool`): Flag to indicate whether the batch dimension is the first dimension
- `std_thresh` (`float`): Threshold for the standard deviation of the input
- `use_std_mask` (`bool`): Flag to indicate whether to use the standard deviation mask
- `use_padded_len_mask` (`bool`): Flag to indicate whether to use the padded length mask
- `observe_until` (`int`): The number of frames to observe before predicting the future
Returns:
---
- `dec_out` (`torch.Tensor`): Output of the decoder
- `loss` (`torch.Tensor`): Loss of the model
Examples:
---
```python
import torch
from skelcast.models.rnn.pvred import PositionalVelocityRecurrentEncoderDecoder
model = PositionalVelocityRecurrentEncoderDecoder(input_dim=75, observe_until=80, loss_fn=torch.nn.MSELoss())
x = torch.randn(32, 100, 75)
dec_out, loss = model(x)
print(dec_out.shape)  # torch.Size([32, 20, 75])
```
"""
def __init__(self, input_dim: int, enc_hidden_dim: int = 64,
dec_hidden_dim: int = 64,
enc_type: str = 'lstm',
dec_type: str = 'lstm',
include_velocity: bool = False,
pos_enc: str = None,
loss_fn: nn.Module = None,
batch_first: bool = True,
std_thresh: float = 1e-4,
use_std_mask: bool = False,
use_padded_len_mask: bool = False,
observe_until: int = 30) -> None:
assert pos_enc in ['concat', 'add', None], f'pos_enc must be one of concat, add, None, got {pos_enc}'
assert isinstance(loss_fn, nn.Module), f'loss_fn must be an instance of torch.nn.Module, got {type(loss_fn)}'
assert enc_type in ['lstm', 'gru'], f'enc_type must be one of lstm, gru, got {enc_type}'
assert dec_type in ['lstm', 'gru'], f'dec_type must be one of lstm, gru, got {dec_type}'
super().__init__()
self.input_dim = input_dim
self.enc_hidden_dim = enc_hidden_dim
self.dec_hidden_dim = dec_hidden_dim
self.enc_type = enc_type
self.batch_first = batch_first
self.std_thresh = std_thresh
self.use_std_mask = use_std_mask
self.use_padded_len_mask = use_padded_len_mask
self.observe_until = observe_until

self.include_velocity = include_velocity
if self.include_velocity:
# Double the input dimension because we are concatenating the xyz velocities to the input
self.input_dim = self.input_dim * 2

self.pos_enc_method = pos_enc
if self.pos_enc_method == 'concat':
self.pos_enc = PositionalEncoding(d_model=input_dim, mode='concat')
self.input_dim = self.input_dim + input_dim
elif self.pos_enc_method == 'add':
self.pos_enc = PositionalEncoding(d_model=input_dim, mode='add')
else:
self.pos_enc = None

self.loss_fn = loss_fn

# Build encoder
# TODO: Convert them to registry-type build
self.encoder = Encoder(rnn_type=enc_type, input_dim=self.input_dim,
hidden_dim=enc_hidden_dim, batch_first=batch_first)
# Build decoder
self.decoder = Decoder(rnn_type=dec_type, input_dim=self.input_dim,
hidden_dim=dec_hidden_dim, batch_first=batch_first)


def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor:
# Calculate the velocity if the include_velocity flag is true
if self.include_velocity:
vel_inp = self._calculate_velocity(x)
# Concatenate the velocity to the input and the targets
x = torch.cat([x, vel_inp], dim=-1)

# If the pos_enc is not None, apply the positional encoding, dependent on the pos_enc_method

if self.pos_enc is not None:
x = self.pos_enc(x)

encoder_input, decoder_initial_value, targets = x[:, :self.observe_until, :], x[:, self.observe_until, :], x[:, self.observe_until:, :]
mask_pred = torch.std(x, dim=1) > self.std_thresh if self.batch_first else torch.std(x, dim=0) > self.std_thresh
# Encode the input
enc_out, enc_hidden = self.encoder(encoder_input)
# Decode the output
dec_out = self.decoder(decoder_initial_value.unsqueeze(1), enc_hidden, timesteps_to_predict=targets.shape[1])

# The decoder's output should have a shape of (batch_size, seq_len, input_dim)
assert dec_out.shape == targets.shape, f'dec_out.shape must be equal to targets.shape, got {dec_out.shape} and {targets.shape}'
        # Apply the padded-length masks to the prediction. The masks cover the full
        # padded input sequence, so keep only the predicted region and broadcast over features.
        if self.use_padded_len_mask:
            dec_out = dec_out * masks[:, self.observe_until:].unsqueeze(-1).float()

        # Apply the per-feature std masks to the prediction (broadcast over the time dimension)
        if self.use_std_mask:
            dec_out = dec_out * mask_pred.unsqueeze(1).float()

# Calculate the loss
loss = self.loss_fn(dec_out, targets)

return dec_out, loss

def _calculate_velocity(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the velocity of the input tensor
Args:
---
- `x` (`torch.Tensor`): Input tensor of shape `(batch_size, seq_len, input_dim)`
Returns:
---
- Velocity tensor of shape (batch_size, seq_len, input_dim)
"""
# Calculate the velocity
velocity = torch.zeros_like(x)
velocity[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
return velocity

    def training_step(self, x: torch.Tensor, y: torch.Tensor) -> dict:
        self.encoder.train()
        self.decoder.train()
        # Forward pass. The targets are sliced out of `x` internally, so the second
        # argument is forwarded as the padding mask that `forward` expects.
        dec_out, loss = self(x, y)
        return {'loss': loss, 'out': dec_out}

    @torch.no_grad()
    def validation_step(self, *args, **kwargs) -> dict:
        self.encoder.eval()
        self.decoder.eval()
        # Forward pass without gradients; bypass training_step so the modules
        # are not switched back into train mode.
        dec_out, loss = self(*args, **kwargs)
        return {'loss': loss, 'out': dec_out}


    @torch.no_grad()
    def predict(self, *args, **kwargs):
        self.encoder.eval()
        self.decoder.eval()
        # Run a plain forward pass with gradients disabled
        return self(*args, **kwargs)
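As a standalone sanity check (not part of the commit) of what `_calculate_velocity` computes: the first frame's velocity is zero, and each later frame holds the difference to its predecessor, which matches `torch.diff` along the time dimension:

```python
import torch

x = torch.randn(2, 5, 3)  # (batch_size, seq_len, input_dim)
velocity = torch.zeros_like(x)
velocity[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]

assert torch.all(velocity[:, 0, :] == 0)                         # frame 0 has no predecessor
assert torch.allclose(velocity[:, 1:, :], torch.diff(x, dim=1))  # frame-to-frame deltas
```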

31 changes: 24 additions & 7 deletions src/skelcast/models/transformers/base.py
@@ -60,23 +60,40 @@ def forward(self, x):


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, mode='add'):
super(PositionalEncoding, self).__init__()
self.mode = mode

# Create a long enough 'PE' matrix with position and dimension indexes
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

# Adjust div_term calculation
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))


        # Assign sine to the even indices
        pe[:, 0::2] = torch.sin(position * div_term)

        # Assign cosine to the odd indices; for an odd d_model there is one fewer
        # odd column than there are div_term entries, so the last term is dropped
        if d_model % 2 == 0:
            pe[:, 1::2] = torch.cos(position * div_term)
        else:
            pe[:, 1::2] = torch.cos(position * div_term[:-1])

        pe = pe.unsqueeze(0)
        # Register pe as a buffer so it moves with the module but is not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        if self.mode == 'add':
            # Add the positional encoding to the input embedding (broadcasts over the batch)
            x = x + self.pe[:, :x.size(1), :]
        elif self.mode == 'concat':
            # Concatenate the positional encoding with the input embedding along the feature dim
            x = torch.cat((x, self.pe[:, :x.size(1), :].repeat(x.size(0), 1, 1)), dim=-1)
        else:
            raise ValueError("Invalid mode. Choose 'add' or 'concat'.")
        return x
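A quick shape check of the two modes, assuming the patched class above; `add` preserves the feature dimension while `concat` doubles it:

```python
import torch
from skelcast.models.transformers.base import PositionalEncoding

x = torch.randn(4, 10, 75)  # (batch_size, seq_len, d_model)
print(PositionalEncoding(d_model=75, mode='add')(x).shape)     # torch.Size([4, 10, 75])
print(PositionalEncoding(d_model=75, mode='concat')(x).shape)  # torch.Size([4, 10, 150])
```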


