diff --git a/configs/sttf_base.yaml b/configs/sttf_base.yaml
index d099aba..097748a 100644
--- a/configs/sttf_base.yaml
+++ b/configs/sttf_base.yaml
@@ -41,6 +41,7 @@ model:
   name: SpatioTemporalTransformer
   args:
     n_joints: 25
+    input_dim: 3
     d_model: 256
     n_blocks: 3
     n_heads: 8
diff --git a/src/skelcast/models/transformers/sttf.py b/src/skelcast/models/transformers/sttf.py
index afc56e5..9a8efac 100644
--- a/src/skelcast/models/transformers/sttf.py
+++ b/src/skelcast/models/transformers/sttf.py
@@ -163,6 +163,7 @@ class SpatioTemporalTransformer(SkelcastModule):
     - dropout `float`: Dropout probability
     """
     def __init__(self, n_joints,
+                 input_dim,
                  d_model,
                  n_blocks,
                  n_heads,
@@ -172,6 +173,7 @@ def __init__(self, n_joints,
                  loss_fn: nn.Module = None,
                  debug=False):
         super().__init__()
+        self.input_dim = input_dim
         self.n_joints = n_joints
         self.d_model = d_model
         self.n_blocks = n_blocks
@@ -183,7 +185,7 @@ def __init__(self, n_joints,
         self.debug = debug
 
         # Embedding projection before feeding into the transformer
-        self.embedding = nn.Linear(in_features=3 * n_joints, out_features=d_model * n_joints, bias=False)
+        self.embedding = nn.Linear(in_features=input_dim * n_joints, out_features=d_model * n_joints, bias=False)
         self.pre_dropout = nn.Dropout(dropout)
         self.pe = PositionalEncoding(d_model=d_model * n_joints)
 
@@ -195,7 +197,7 @@ def __init__(self, n_joints,
                                       dropout=dropout,
                                       n_joints=n_joints)
 
-        self.linear_out = nn.Linear(in_features=d_model, out_features=3, bias=False)
+        self.linear_out = nn.Linear(in_features=d_model, out_features=input_dim, bias=False)
 
     def forward(self, x: torch.Tensor):
         if x.ndim > 4: