
Commit

Merge pull request #19 from kaseris/fix/collate_fn
Data Loader behaviour change
kaseris authored Nov 27, 2023
2 parents bf9d21b + 18ca0e5 commit bfdec07
Showing 4 changed files with 170 additions and 27 deletions.
77 changes: 50 additions & 27 deletions src/skelcast/data/dataset.py
@@ -101,25 +101,62 @@ def nturbgd_collate_fn_with_overlapping_context_window(batch: List[NTURGBDSample
batch_label = default_collate(batch_label)
return NTURGBDSample(x=batch_x, y=batch_y, label=batch_label)


class NTURGBDCollateFn:
def __init__(self, block_size: int) -> None:
"""
Custom collate function for batched variable-length sequences.
        During the __call__ method, we create `block_size`-long context windows for each sequence in the batch.
        If is_packed is True, we pack the padded sequences; otherwise, we return the padded sequences as is.
Args:
- block_size (int): Sequence's context length.
- is_packed (bool): Whether to pack the padded sequence or not.
Returns:
            The batched padded sequences, ready to be fed to a transformer or an LSTM model.
"""
def __init__(self, block_size: int, is_packed: bool = False) -> None:
self.block_size = block_size
self.is_packed = is_packed

def __call__(self, batch) -> NTURGBDSample:
seq_lens = [sample.shape[0] for sample, _ in batch]
labels = [label for _, label in batch]
context = []
target = []
for seq_len, sample in zip(seq_lens, batch):
x, _ = sample
idx = torch.randint(seq_len - self.block_size, (1, ))
context.append(x[idx:idx + self.block_size])
target.append(x[idx + 1:idx + self.block_size + 1])
labels_batch = default_collate(labels)
return NTURGBDSample(x=torch.stack(context),
y=torch.stack(target),
label=labels_batch)

        # A dataset sample has shape (seq_len, n_bodies, n_joints, 3).
        # For each sample we build overlapping context windows of length `block_size`
        # and stack them, so the padded batch has shape
        # (batch_size, max_num_windows, block_size, n_bodies, n_joints, 3).
        # The target windows are the inputs shifted one timestep to the right,
        # so the model learns to predict the next timestep at every position.
batch_x = []
batch_y = []
for sample, _ in batch:
x, y = self.get_windows(sample)
batch_x.append(x)
batch_y.append(y)
# Pad the sequences to the maximum sequence length in the batch
batch_x = torch.nn.utils.rnn.pad_sequence(batch_x, batch_first=True)
batch_y = torch.nn.utils.rnn.pad_sequence(batch_y, batch_first=True)
        if self.is_packed:
            # The padded batches hold seq_len - block_size windows per sample, so the
            # packing lengths must be the window counts, not the raw sequence lengths.
            window_counts = [seq_len - self.block_size for seq_len in seq_lens]
            batch_x = torch.nn.utils.rnn.pack_padded_sequence(batch_x, window_counts, batch_first=True, enforce_sorted=False)
            batch_y = torch.nn.utils.rnn.pack_padded_sequence(batch_y, window_counts, batch_first=True, enforce_sorted=False)
labels = default_collate(labels)
return NTURGBDSample(x=batch_x, y=batch_y, label=labels)

def get_windows(self, x):
seq_len = x.shape[0]
input_windows = []
target_labels = []
for i in range(seq_len - self.block_size):
window = x[i:i + self.block_size, ...]
target_label = x[i + 1:i + self.block_size + 1, ...]
input_windows.append(window)
target_labels.append(target_label)
input_windows = np.array(input_windows)
input_windows_tensor = torch.tensor(input_windows, dtype=torch.float)
target_labels_tensor = torch.tensor(np.array(target_labels), dtype=torch.float)
return input_windows_tensor, target_labels_tensor
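
A minimal usage sketch of the new collate function follows. The import path mirrors this file's location and the NTURGBDDataset constructor argument is a placeholder; only the NTURGBDCollateFn arguments and the output shapes come from the code above.

from torch.utils.data import DataLoader

from skelcast.data.dataset import NTURGBDDataset, NTURGBDCollateFn  # assumed import path

# Placeholder instantiation; the real NTURGBDDataset signature may differ.
dataset = NTURGBDDataset('data/nturgb+d_skeletons')

# block_size sets the context window length; is_packed=True additionally packs
# the padded windows for RNN-style consumers such as an LSTM.
collate_fn = NTURGBDCollateFn(block_size=8, is_packed=False)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

batch = next(iter(loader))
# batch.x: (batch_size, max_num_windows, block_size, n_bodies, n_joints, 3)
# batch.y: same shape as batch.x, shifted one timestep into the future
# batch.label: collated action labels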


class NTURGBDDataset(Dataset):
def __init__(
@@ -165,20 +202,6 @@ def load_labels(self):
code, label = parts
# Map the code to a tuple of an integer (extracted from the code) and the label
self.labels_dict[code] = (int(code[1:])-1, label)

def get_windows(self, x):
seq_len = x.shape[0]
input_windows = []
target_labels = []
for i in range(seq_len - self.max_context_window):
window = x[i:i + self.max_context_window, ...]
target_label = x[i + self.max_context_window, ...]
input_windows.append(window)
target_labels.append(target_label)
input_windows = np.array(input_windows)
input_windows_tensor = torch.tensor(input_windows, dtype=torch.float)
target_labels_tensor = torch.tensor(np.array(target_labels), dtype=torch.float).unsqueeze(0)
return input_windows_tensor, target_labels_tensor


def __getitem__(self, index) -> torch.Tensor:
Empty file.
102 changes: 102 additions & 0 deletions src/skelcast/models/transformers/base.py
@@ -0,0 +1,102 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PreNorm(nn.Module):
def __init__(self, dim, fn) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class MLP(nn.Module):
def __init__(self, dim, embedding_dim, dropout=0.1) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, embedding_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(embedding_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x):
return self.net(x)


class MultiHeadSelfAttention(nn.Module):
def __init__(self, d_model, n_heads=8, inner_head_dim=64, dropout=0.1) -> None:
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.inner_head_dim = inner_head_dim

self.scale = self.inner_head_dim ** -.5

self.per_head_dimensionality = self.inner_head_dim * self.n_heads
self.pre_dropout = nn.Dropout(dropout)
self.to_qkv_chunk = nn.Linear(self.d_model, self.per_head_dimensionality * 3,
bias=False)
self.out_proj = nn.Linear(self.per_head_dimensionality, self.d_model,
bias=False)
self.out_dropout = nn.Dropout(dropout)

def forward(self, x):
qkv = self.to_qkv_chunk(x).chunk(3, dim=-1)
q, k, v = map(lambda t: t.view(-1, self.n_heads, self.inner_head_dim).permute(1, 0, 2), qkv)
attn = F.softmax(torch.matmul(q, k.permute(0, 2, 1) * self.scale), dim=-1)
attn = self.pre_dropout(attn)
attn = attn @ v
attn = attn.permute(1, 0, 2).contiguous()
attn = attn.view(-1, self.n_heads * self.inner_head_dim)
out = self.out_proj(attn)
return self.out_dropout(out)
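
A short shape walkthrough of the attention block, assuming an unbatched (seq_len, d_model) input as the view and permute calls above imply; the hyperparameter values are illustrative only.

import torch

mhsa = MultiHeadSelfAttention(d_model=128, n_heads=8, inner_head_dim=64, dropout=0.1)
x = torch.randn(16, 128)   # (seq_len, d_model)
# to_qkv_chunk: (16, 128) -> (16, 3 * 8 * 64), chunked into q, k, v of shape (16, 512) each
# per head: q @ k^T yields an (8, 16, 16) attention map, softmaxed over the last dim
# attn @ v: (8, 16, 64), merged back to (16, 512) and projected to (16, 128)
out = mhsa(x)              # (seq_len, d_model)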


class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
# Create a long enough 'PE' matrix with position and dimension indexes
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

pe = pe.unsqueeze(0).transpose(0, 1)
# Registers pe as a buffer that should not be considered a model parameter.
self.register_buffer('pe', pe)

def forward(self, x):
# Adds the positional encoding vector to the input embedding vector
x = x + self.pe[:x.size(0), :]
return x
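
The buffer is registered as (max_len, 1, d_model), so the module expects sequence-first input; a brief sketch with an illustrative embedding size:

import torch

pos_enc = PositionalEncoding(d_model=128, max_len=5000)
x = torch.randn(16, 4, 128)   # (seq_len, batch_size, d_model), sequence-first
x = pos_enc(x)                # same shape; sin/cos offsets broadcast over the batch dimension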


class Transformer(nn.Module):
def __init__(self, dim, n_blocks, n_heads, dim_head, mlp_dim, dropout) -> None:
super().__init__()
self.blocks = nn.ModuleList([])
for _ in range(n_blocks):
self.blocks.append(
nn.ModuleList([
PreNorm(dim, MultiHeadSelfAttention(dim,
n_heads,
dim_head,
dropout)),
PreNorm(dim, MLP(dim, mlp_dim, dropout))
])
)

    def forward(self, x):
        for attn, mlp in self.blocks:
            x = attn(x) + x
            x = mlp(x) + x
        return x
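
A minimal end-to-end sketch of the block stack; the hyperparameters are chosen for illustration and match the unbatched (seq_len, d_model) convention of MultiHeadSelfAttention above.

import torch

model = Transformer(dim=128, n_blocks=4, n_heads=8, dim_head=64, mlp_dim=256, dropout=0.1)
x = torch.randn(16, 128)   # (seq_len, d_model)
out = model(x)             # (seq_len, d_model), after four pre-norm attention + MLP blocks with residuals
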
18 changes: 18 additions & 0 deletions src/skelcast/models/transformers/transformer.py
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn

import skelcast.models.transformers.base as base
from skelcast.models import SkelcastModule

class ForecastTransformer(SkelcastModule):
def __init__(self) -> None:
super().__init__()

def training_step(self, *args, **kwargs):
return super().training_step(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return super().validation_step(*args, **kwargs)

def predict(self, *args, **kwargs):
return super().predict(*args, **kwargs)
