Merge pull request #59 from kaseris/models/sttf
Models/sttf
Showing 3 changed files with 291 additions and 0 deletions.

@@ -0,0 +1,62 @@
dataset:
  name: NTURGBDDataset
  args:
    data_directory: /home/kaseris/Documents/mount/data_ntu_rgbd
    label_file: /home/kaseris/Documents/dev/skelcast/data/labels.txt
    missing_files_dir: /home/kaseris/Documents/dev/skelcast/data/missing
    max_context_window: 10
    max_number_of_bodies: 1
    max_duration: 300
    n_joints: 25
    cache_file: /home/kaseris/Documents/mount/dataset_cache.pkl

transforms:
  - name: MinMaxScaleTransform
    args:
      feature_scale: [0.0, 1.0]
  - name: CartToQuaternionTransform
    args:
      parents: null

loss:
  name: SmoothL1Loss
  args:
    reduction: mean
    beta: 0.01

collate_fn:
  name: NTURGBDCollateFnWithRandomSampledContextWindow
  args:
    block_size: 10

logger:
  name: TensorboardLogger
  args:
    log_dir: runs

optimizer:
  name: AdamW
  args:
    lr: 0.0001
    weight_decay: 0.0001

model:
  name: SpatioTemporalTransformer
  args:
    n_joints: 25
    d_model: 256
    n_blocks: 8
    n_heads: 8
    d_head: 16
    mlp_dim: 512
    dropout: 0.5

runner:
  name: Runner
  args:
    train_batch_size: 1024
    val_batch_size: 1024
    block_size: 8
    log_gradient_info: true
    device: cuda
    n_epochs: 100
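
The `model` section of this config mirrors the constructor of the `SpatioTemporalTransformer` added below. As a rough sketch (not part of this diff; the config path and the module path of the new file are assumptions), the hyperparameters can be loaded from the YAML and passed straight to the model:

import yaml

from skelcast.models.transformers.sttf import SpatioTemporalTransformer  # module path assumed

# Hypothetical illustration: build the model directly from the `model.args` block of the config.
with open('configs/sttf.yaml') as f:  # placeholder path
    cfg = yaml.safe_load(f)

model = SpatioTemporalTransformer(**cfg['model']['args'])
# Equivalent to:
# SpatioTemporalTransformer(n_joints=25, d_model=256, n_blocks=8,
#                           n_heads=8, d_head=16, mlp_dim=512, dropout=0.5)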

@@ -0,0 +1,228 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from skelcast.models.module import SkelcastModule
from skelcast.models.transformers.base import PositionalEncoding
from skelcast.models import MODELS


class TemporalMultiHeadAttentionBlock(nn.Module):
    '''
    Args:
        n_heads `int`: Number of heads
        d_model `int`: The input dimensionality
        d_head `int`: The per-head dimensionality
    '''
    def __init__(self, n_heads, d_head, n_joints, d_model, dropout=0.1, debug=False):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_head = d_head
        self.n_joints = n_joints
        self.debug = debug

        self.q = nn.Linear(in_features=d_model * n_joints, out_features=n_joints * d_head * n_heads, bias=False)
        self.k = nn.Linear(in_features=d_model * n_joints, out_features=n_joints * d_head * n_heads, bias=False)
        self.v = nn.Linear(in_features=d_model * n_joints, out_features=n_joints * d_head * n_heads, bias=False)

        self.back_proj = nn.Linear(in_features=d_head * n_heads, out_features=d_model, bias=False)  # Project back to original dimensionality.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, n_joints, dims = x.shape
        assert n_joints == self.n_joints, f'Expected n_joints: {self.n_joints}. Got: n_joints: {n_joints}'
        assert dims == self.d_model, f'Expected d_model: {self.d_model}. Got: d_model: {dims}'
        # Flatten the joint and feature dimensions so attention runs over time steps.
        x = x.view(batch_size, seq_len, self.n_joints * self.d_model)
        q_proj = self.q(x)
        k_proj = self.k(x)
        v_proj = self.v(x)
        # Build the causal mask on the same device as the input to avoid device mismatches.
        mask = self.get_mask(seq_len, batch_size).to(x.device)
        attn_prod_ = torch.bmm(q_proj, k_proj.permute(0, 2, 1)) * (self.d_model) ** -0.5

        attn_temporal = F.softmax(attn_prod_ + mask, dim=-1)
        attn = attn_temporal @ v_proj
        out = self.back_proj(attn.view(batch_size, seq_len, n_joints, -1))
        out = self.dropout(out)
        if self.debug:
            return out.view(batch_size, seq_len, n_joints, dims), attn, attn_temporal
        return out.view(batch_size, seq_len, n_joints, dims)

    def get_mask(self, seq_len, batch_size):
        # Upper-triangular -inf mask so each time step attends only to past steps.
        mask = torch.triu(torch.ones((seq_len, seq_len)) * float('-inf'), diagonal=1)
        return mask.repeat(batch_size, 1, 1)


class SpatialMultiHeadAttentionBlock(nn.Module):
    def __init__(self, n_heads, d_head, n_joints, d_model, dropout=0.1, debug=False):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_head = d_head
        self.n_joints = n_joints
        self.debug = debug

        # The query projection treats all joints differently, so stays as is
        self.q_spatial = nn.Linear(in_features=n_joints * d_model, out_features=n_joints * n_heads * d_head, bias=False)
        # The key and value projections are shared across all joints and time steps
        self.k_spatial = nn.Linear(in_features=d_model, out_features=n_heads * d_head, bias=False)
        self.v_spatial = nn.Linear(in_features=d_model, out_features=n_heads * d_head, bias=False)
        self.back_proj_spatial = nn.Linear(in_features=n_heads * d_head, out_features=d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch_size, seq_len, n_joints, dims = x.shape
        assert n_joints == self.n_joints, f'Expected n_joints: {self.n_joints}. Got: n_joints: {n_joints}'
        assert dims == self.d_model, f'Expected d_model: {self.d_model}. Got: d_model: {dims}'
        # Treat all joints individually to compute the query
        q_proj_spatial = self.q_spatial(x.view(batch_size, seq_len, n_joints * self.d_model))
        q_proj_spatial = q_proj_spatial.view(batch_size, n_joints, self.n_heads * self.d_head * seq_len)
        k_proj_spatial = self.k_spatial(x).view(batch_size, n_joints, self.n_heads * self.d_head * seq_len)
        v_proj_spatial = self.v_spatial(x).view(batch_size, n_joints, self.n_heads * self.d_head * seq_len)

        attn_prod_ = torch.bmm(q_proj_spatial, k_proj_spatial.permute(0, 2, 1)) * (self.d_model) ** -0.5

        attn_spatial = F.softmax(attn_prod_, dim=-1)

        mha_attn_spatial = attn_spatial @ v_proj_spatial
        spatial_attn_out = self.back_proj_spatial(mha_attn_spatial.view(batch_size, n_joints, seq_len, self.n_heads * self.d_head).permute(0, 2, 1, 3))
        spatial_attn_out = self.dropout(spatial_attn_out)
        if self.debug:
            return spatial_attn_out, mha_attn_spatial, attn_spatial
        return spatial_attn_out


class PostNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Apply the wrapped function, add the residual connection, then normalize.
        return self.norm(self.fn(x, **kwargs) + x)


class MLP(nn.Module):
    def __init__(self, dim, embedding_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, embedding_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Transformer(nn.Module):
    def __init__(self, mlp_dim, dim, n_blocks, n_heads, d_head, dropout,
                 n_joints):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(n_blocks):
            self.blocks.append(
                nn.ModuleList(
                    [PostNorm(dim, TemporalMultiHeadAttentionBlock(n_heads=n_heads,
                                                                   d_head=d_head,
                                                                   n_joints=n_joints,
                                                                   dropout=dropout,
                                                                   d_model=dim)),
                     PostNorm(dim, SpatialMultiHeadAttentionBlock(n_heads=n_heads,
                                                                  d_head=d_head,
                                                                  n_joints=n_joints,
                                                                  d_model=dim,
                                                                  dropout=dropout)),
                     PostNorm(dim, MLP(dim=dim, embedding_dim=mlp_dim, dropout=dropout))])
            )

    def forward(self, x):
        # Each block computes temporal and spatial attention in parallel, sums them,
        # passes the result through the MLP, and feeds its output to the next block.
        for tmp_attn, spa_attn, mlp in self.blocks:
            o_1 = tmp_attn(x)
            o_2 = spa_attn(x)
            x = mlp(o_1 + o_2)
        return x


@MODELS.register_module()
class SpatioTemporalTransformer(SkelcastModule):
    """
    PyTorch implementation of the model proposed in the paper:
    "A Spatio-temporal Transformer for 3D Human Motion Prediction"
    https://arxiv.org/abs/2004.08692

    Args:
    - n_joints `int`: Number of joints in the skeleton
    - d_model `int`: The input dimensionality after the linear projection that computes the skeleton joint representations
    - n_blocks `int`: Number of transformer blocks
    - n_heads `int`: Number of self-attention heads (for both temporal and spatial attention)
    - d_head `int`: The per-head dimensionality
    - mlp_dim `int`: The dimensionality of the MLP
    - dropout `float`: Dropout probability
    """
    def __init__(self, n_joints,
                 d_model,
                 n_blocks,
                 n_heads,
                 d_head,
                 mlp_dim,
                 dropout,
                 loss_fn: nn.Module = None,
                 debug=False):
        super().__init__()
        self.n_joints = n_joints
        self.d_model = d_model
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_head = d_head
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.loss_fn = nn.SmoothL1Loss() if loss_fn is None else loss_fn
        self.debug = debug

        # Embedding projection before feeding into the transformer
        self.embedding = nn.Linear(in_features=3 * n_joints, out_features=d_model * n_joints, bias=False)
        self.pre_dropout = nn.Dropout(dropout)
        self.pe = PositionalEncoding(d_model=d_model * n_joints)

        self.transformer = Transformer(mlp_dim=mlp_dim,
                                       dim=d_model,
                                       n_blocks=n_blocks,
                                       n_heads=n_heads,
                                       d_head=d_head,
                                       dropout=dropout,
                                       n_joints=n_joints)

        self.linear_out = nn.Linear(in_features=d_model, out_features=3, bias=False)

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, n_joints, dims = x.shape
        input_ = x.view(batch_size, seq_len, n_joints * dims)
        o = self.embedding(input_)
        if self.debug:
            print(f'o shape after embedding: {o.shape}')
        o = self.pe.pe.repeat(batch_size, 1, 1)[:, :seq_len, :] + o
        if self.debug:
            print(f'o shape after positional encoding: {o.shape}')
        o = self.pre_dropout(o)
        o = o.view(batch_size, seq_len, n_joints, self.d_model)
        o = self.transformer(o)
        if self.debug:
            print(f'o shape after transformer: {o.shape}')
        # Residual connection: the network predicts an offset from the input pose.
        out = self.linear_out(o) + x
        return out

    def training_step(self, **kwargs) -> dict:
        # Retrieve the x and y from the keyword arguments
        x, y = kwargs['x'], kwargs['y']
        # Forward pass
        out = self(x)
        # Compute the loss
        loss = self.loss_fn(out, y)
        return {'loss': loss, 'out': out}

    def validation_step(self, *args, **kwargs):
        with torch.no_grad():
            return self.training_step(*args, **kwargs)

    def predict(self, *args, **kwargs):
        pass
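
As a quick sanity check of the new module (not part of the diff; the dummy batch size and sequence length are arbitrary, the 25-joint, 3-coordinate input layout follows the `embedding` layer and the config above, and it is assumed that the `PositionalEncoding` buffer broadcasts as used in `forward`):

import torch

# Hypothetical smoke test of the forward and training_step paths.
model = SpatioTemporalTransformer(n_joints=25, d_model=256, n_blocks=8,
                                  n_heads=8, d_head=16, mlp_dim=512, dropout=0.5)

x = torch.randn(2, 10, 25, 3)  # (batch, seq_len, n_joints, xyz)
y = torch.randn(2, 10, 25, 3)  # target poses, same shape

out = model(x)                        # -> torch.Size([2, 10, 25, 3])
step = model.training_step(x=x, y=y)  # {'loss': ..., 'out': ...}
print(step['loss'].item(), step['out'].shape)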