From 7300391ffda2c6fff2a10e570ce97a9e527ff5ba Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Wed, 31 Jan 2024 11:41:40 +0200 Subject: [PATCH 1/2] Define the input dimensionality in the config file. May be different if using quaternions/expmaps/xyz. --- configs/sttf_base.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/sttf_base.yaml b/configs/sttf_base.yaml index d099aba..097748a 100644 --- a/configs/sttf_base.yaml +++ b/configs/sttf_base.yaml @@ -41,6 +41,7 @@ model: name: SpatioTemporalTransformer args: n_joints: 25 + input_dim: 3 d_model: 256 n_blocks: 3 n_heads: 8 From 5e446a0389b36b1321e17475124ffd5474c4db49 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Wed, 31 Jan 2024 11:42:20 +0200 Subject: [PATCH 2/2] Required parameter: Input dimensionality. Adapt to new input modalities --- src/skelcast/models/transformers/sttf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/skelcast/models/transformers/sttf.py b/src/skelcast/models/transformers/sttf.py index afc56e5..9a8efac 100644 --- a/src/skelcast/models/transformers/sttf.py +++ b/src/skelcast/models/transformers/sttf.py @@ -163,6 +163,7 @@ class SpatioTemporalTransformer(SkelcastModule): - dropout `float`: Dropout probability """ def __init__(self, n_joints, + input_dim, d_model, n_blocks, n_heads, @@ -172,6 +173,7 @@ def __init__(self, n_joints, loss_fn: nn.Module = None, debug=False): super().__init__() + self.input_dim = input_dim self.n_joints = n_joints self.d_model = d_model self.n_blocks = n_blocks @@ -183,7 +185,7 @@ def __init__(self, n_joints, self.debug = debug # Embedding projection before feeding into the transformer - self.embedding = nn.Linear(in_features=3 * n_joints, out_features=d_model * n_joints, bias=False) + self.embedding = nn.Linear(in_features=input_dim * n_joints, out_features=d_model * n_joints, bias=False) self.pre_dropout = nn.Dropout(dropout) self.pe = PositionalEncoding(d_model=d_model * n_joints) @@ -195,7 +197,7 @@ def
__init__(self, n_joints, dropout=dropout, n_joints=n_joints) - self.linear_out = nn.Linear(in_features=d_model, out_features=3, bias=False) + self.linear_out = nn.Linear(in_features=d_model, out_features=input_dim, bias=False) def forward(self, x: torch.Tensor): if x.ndim > 4: