Merge pull request #59 from kaseris/models/sttf
Models/sttf
kaseris committed Jan 4, 2024
2 parents 2c98825 + 02c809f commit be590f5
Showing 3 changed files with 291 additions and 0 deletions.
configs/sttf_base.yaml (62 additions, 0 deletions)
@@ -0,0 +1,62 @@
dataset:
  name: NTURGBDDataset
  args:
    data_directory: /home/kaseris/Documents/mount/data_ntu_rgbd
    label_file: /home/kaseris/Documents/dev/skelcast/data/labels.txt
    missing_files_dir: /home/kaseris/Documents/dev/skelcast/data/missing
    max_context_window: 10
    max_number_of_bodies: 1
    max_duration: 300
    n_joints: 25
    cache_file: /home/kaseris/Documents/mount/dataset_cache.pkl

transforms:
  - name: MinMaxScaleTransform
    args:
      feature_scale: [0.0, 1.0]
  - name: CartToQuaternionTransform
    args:
      parents: null

loss:
  name: SmoothL1Loss
  args:
    reduction: mean
    beta: 0.01

collate_fn:
  name: NTURGBDCollateFnWithRandomSampledContextWindow
  args:
    block_size: 10

logger:
  name: TensorboardLogger
  args:
    log_dir: runs

optimizer:
  name: AdamW
  args:
    lr: 0.0001
    weight_decay: 0.0001

model:
  name: SpatioTemporalTransformer
  args:
    n_joints: 25
    d_model: 256
    n_blocks: 8
    n_heads: 8
    d_head: 16
    mlp_dim: 512
    dropout: 0.5

runner:
  name: Runner
  args:
    train_batch_size: 1024
    val_batch_size: 1024
    block_size: 8
    log_gradient_info: true
    device: cuda
    n_epochs: 100
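
For reference, the model block above maps one-to-one onto the constructor arguments of the new SpatioTemporalTransformer added below in src/skelcast/models/transformers/sttf.py. A minimal sketch of building the model straight from this YAML, assuming only PyYAML and the class added in this PR; the config path is illustrative, and how the repo's Runner and registry consume the rest of the file is not shown here:

import yaml

from skelcast.models.transformers.sttf import SpatioTemporalTransformer

# Illustrative path, relative to the repository root.
with open('configs/sttf_base.yaml') as f:
    cfg = yaml.safe_load(f)

# The 'model.args' mapping feeds the constructor directly.
model = SpatioTemporalTransformer(**cfg['model']['args'])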
src/skelcast/models/__init__.py (1 addition, 0 deletions)
@@ -7,6 +7,7 @@

from .rnn.lstm import SimpleLSTMRegressor
from .transformers.transformer import ForecastTransformer
from .transformers.sttf import SpatioTemporalTransformer
from .rnn.pvred import PositionalVelocityRecurrentEncoderDecoder
from .rnn.pvred import Encoder, Decoder
from .cnn.unet import Unet
src/skelcast/models/transformers/sttf.py (228 additions, 0 deletions)
@@ -0,0 +1,228 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from skelcast.models.module import SkelcastModule
from skelcast.models.transformers.base import PositionalEncoding
from skelcast.models import MODELS


class TemporalMultiHeadAttentionBlock(nn.Module):
    '''
    Multi-head self-attention over the temporal dimension: each frame's joints are
    flattened into a single token and frames attend only to themselves and the past.

    Args:
        n_heads `int`: Number of attention heads
        d_head `int`: The per-head dimensionality
        n_joints `int`: Number of skeleton joints
        d_model `int`: The per-joint input dimensionality
        dropout `float`: Dropout probability
        debug `bool`: If True, also return the attention tensors
    '''
    def __init__(self, n_heads, d_head, n_joints, d_model, dropout=0.1, debug=False):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_head = d_head
        self.n_joints = n_joints
        self.debug = debug

        self.q = nn.Linear(in_features=d_model * n_joints, out_features=n_joints * d_head * n_heads, bias=False)
        self.k = nn.Linear(in_features=d_model * n_joints, out_features=n_joints * d_head * n_heads, bias=False)
        self.v = nn.Linear(in_features=d_model * n_joints, out_features=n_joints * d_head * n_heads, bias=False)

        self.back_proj = nn.Linear(in_features=d_head * n_heads, out_features=d_model, bias=False)  # Project back to the original dimensionality.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, n_joints, dims = x.shape
        assert n_joints == self.n_joints, f'Expected n_joints: {self.n_joints}. Got: n_joints: {n_joints}'
        assert dims == self.d_model, f'Expected d_model: {self.d_model}. Got: d_model: {dims}'
        x = x.view(batch_size, seq_len, self.n_joints * self.d_model)
        q_proj = self.q(x)
        k_proj = self.k(x)
        v_proj = self.v(x)
        # Build the causal mask on the input's device to avoid CPU/GPU mismatches.
        mask = self.get_mask(seq_len, batch_size, device=x.device)
        attn_prod_ = torch.bmm(q_proj, k_proj.permute(0, 2, 1)) * (self.d_model) ** -0.5

        attn_temporal = F.softmax(attn_prod_ + mask, dim=-1)
        attn = attn_temporal @ v_proj
        out = self.back_proj(attn.view(batch_size, seq_len, n_joints, -1))
        out = self.dropout(out)
        if self.debug:
            return out.view(batch_size, seq_len, n_joints, dims), attn, attn_temporal
        return out.view(batch_size, seq_len, n_joints, dims)

    def get_mask(self, seq_len, batch_size, device=None):
        # Upper-triangular -inf mask so each time step attends only to itself and earlier steps.
        mask = torch.triu(torch.ones((seq_len, seq_len), device=device) * float('-inf'), diagonal=1)
        return mask.repeat(batch_size, 1, 1)


class SpatialMultiHeadAttentionBlock(nn.Module):
    def __init__(self, n_heads, d_head, n_joints, d_model, dropout=0.1, debug=False):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_head = d_head
        self.n_joints = n_joints
        self.debug = debug

        # The query projection treats all joints differently, so it stays as is.
        self.q_spatial = nn.Linear(in_features=n_joints * d_model, out_features=n_joints * n_heads * d_head, bias=False)
        # The key and value projections are shared across all joints and time steps.
        self.k_spatial = nn.Linear(in_features=d_model, out_features=n_heads * d_head, bias=False)
        self.v_spatial = nn.Linear(in_features=d_model, out_features=n_heads * d_head, bias=False)
        self.back_proj_spatial = nn.Linear(in_features=n_heads * d_head, out_features=d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch_size, seq_len, n_joints, dims = x.shape
        assert n_joints == self.n_joints, f'Expected n_joints: {self.n_joints}. Got: n_joints: {n_joints}'
        assert dims == self.d_model, f'Expected d_model: {self.d_model}. Got: d_model: {dims}'
        # Treat all joints individually to compute the query.
        q_proj_spatial = self.q_spatial(x.view(batch_size, seq_len, n_joints * self.d_model))
        q_proj_spatial = q_proj_spatial.view(batch_size, n_joints, self.n_heads * self.d_head * seq_len)
        k_proj_spatial = self.k_spatial(x).view(batch_size, n_joints, self.n_heads * self.d_head * seq_len)
        v_proj_spatial = self.v_spatial(x).view(batch_size, n_joints, self.n_heads * self.d_head * seq_len)

        # Attention over the joint dimension: (batch_size, n_joints, n_joints).
        attn_prod_ = torch.bmm(q_proj_spatial, k_proj_spatial.permute(0, 2, 1)) * (self.d_model) ** -.5

        attn_spatial = F.softmax(attn_prod_, dim=-1)

        mha_attn_spatial = attn_spatial @ v_proj_spatial
        spatial_attn_out = self.back_proj_spatial(mha_attn_spatial.view(batch_size, n_joints, seq_len, self.n_heads * self.d_head).permute(0, 2, 1, 3))
        spatial_attn_out = self.dropout(spatial_attn_out)
        if self.debug:
            return spatial_attn_out, mha_attn_spatial, attn_spatial
        return spatial_attn_out


class PostNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Post-norm residual: apply the wrapped module, add the skip connection, then normalize.
        return self.norm(self.fn(x, **kwargs) + x)


class MLP(nn.Module):
    def __init__(self, dim, embedding_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, embedding_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Transformer(nn.Module):
    def __init__(self, mlp_dim, dim, n_blocks, n_heads, d_head, dropout,
                 n_joints):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(n_blocks):
            self.blocks.append(
                nn.ModuleList(
                    [PostNorm(dim, TemporalMultiHeadAttentionBlock(n_heads=n_heads,
                                                                   d_head=d_head,
                                                                   n_joints=n_joints,
                                                                   dropout=dropout,
                                                                   d_model=dim)),
                     PostNorm(dim, SpatialMultiHeadAttentionBlock(n_heads=n_heads,
                                                                  d_head=d_head,
                                                                  n_joints=n_joints,
                                                                  d_model=dim,
                                                                  dropout=dropout)),
                     PostNorm(dim, MLP(dim=dim, embedding_dim=mlp_dim, dropout=dropout))])
            )

    def forward(self, x):
        # Each block computes temporal and spatial attention on its input in parallel,
        # sums them, passes the result through the MLP, and feeds it to the next block.
        for tmp_attn, spa_attn, mlp in self.blocks:
            o_1 = tmp_attn(x)
            o_2 = spa_attn(x)
            x = mlp(o_1 + o_2)
        return x


@MODELS.register_module()
class SpatioTemporalTransformer(SkelcastModule):
    """
    PyTorch implementation of the model proposed in the paper:
    "A Spatio-temporal Transformer for 3D Human Motion Prediction"
    https://arxiv.org/abs/2004.08692

    Args:
        - n_joints `int`: Number of joints in the skeleton
        - d_model `int`: The per-joint dimensionality after the linear projection that computes the skeleton joint representations
        - n_blocks `int`: Number of transformer blocks
        - n_heads `int`: Number of self-attention heads (for both temporal and spatial attention)
        - d_head `int`: The per-head dimensionality
        - mlp_dim `int`: The hidden dimensionality of the MLP
        - dropout `float`: Dropout probability
        - loss_fn `nn.Module`: Loss function; defaults to SmoothL1Loss
        - debug `bool`: If True, print intermediate tensor shapes during the forward pass
    """
    def __init__(self, n_joints,
                 d_model,
                 n_blocks,
                 n_heads,
                 d_head,
                 mlp_dim,
                 dropout,
                 loss_fn: nn.Module = None,
                 debug=False):
        super().__init__()
        self.n_joints = n_joints
        self.d_model = d_model
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_head = d_head
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.loss_fn = nn.SmoothL1Loss() if loss_fn is None else loss_fn
        self.debug = debug

        # Embedding projection before feeding into the transformer
        self.embedding = nn.Linear(in_features=3 * n_joints, out_features=d_model * n_joints, bias=False)
        self.pre_dropout = nn.Dropout(dropout)
        self.pe = PositionalEncoding(d_model=d_model * n_joints)

        self.transformer = Transformer(mlp_dim=mlp_dim,
                                       dim=d_model,
                                       n_blocks=n_blocks,
                                       n_heads=n_heads,
                                       d_head=d_head,
                                       dropout=dropout,
                                       n_joints=n_joints)

        self.linear_out = nn.Linear(in_features=d_model, out_features=3, bias=False)

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, n_joints, dims = x.shape
        # Flatten the joints of each frame and project them to the model dimensionality.
        input_ = x.view(batch_size, seq_len, n_joints * dims)
        o = self.embedding(input_)
        if self.debug:
            print(f'o shape after embedding: {o.shape}')
        o = self.pe.pe.repeat(batch_size, 1, 1)[:, :seq_len, :] + o
        if self.debug:
            print(f'o shape after positional encoding: {o.shape}')
        o = self.pre_dropout(o)
        o = o.view(batch_size, seq_len, n_joints, self.d_model)
        o = self.transformer(o)
        if self.debug:
            print(f'o shape after transformer: {o.shape}')
        # Residual connection: the network predicts an offset from the input pose.
        out = self.linear_out(o) + x
        return out

    def training_step(self, **kwargs) -> dict:
        # Retrieve the x and y from the keyword arguments
        x, y = kwargs['x'], kwargs['y']
        # Forward pass
        out = self(x)
        # Compute the loss
        loss = self.loss_fn(out, y)
        return {'loss': loss, 'out': out}

    def validation_step(self, *args, **kwargs):
        with torch.no_grad():
            return self.training_step(*args, **kwargs)

    def predict(self, *args, **kwargs):
        pass
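
For a quick shape sanity check, here is a hedged usage sketch of the new model on random tensors; the sizes are toy values rather than the config above, and the (batch, seq_len, n_joints, 3) layout is assumed from the forward method:

import torch

from skelcast.models.transformers.sttf import SpatioTemporalTransformer

# Toy sizes for illustration; configs/sttf_base.yaml uses n_joints=25, d_model=256, 8 blocks.
model = SpatioTemporalTransformer(n_joints=25, d_model=32, n_blocks=2,
                                  n_heads=4, d_head=8, mlp_dim=64, dropout=0.1)

x = torch.randn(2, 10, 25, 3)   # (batch, seq_len, n_joints, xyz) context window
y = torch.randn(2, 10, 25, 3)   # target poses with the same layout

out = model(x)                  # same shape as x: the input pose plus a predicted offset
step = model.training_step(x=x, y=y)
print(out.shape, step['loss'].item())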
