Skip to content

Commit

Permalink
Merge pull request #52 from kaseris/fix/pvred-model-input
Browse files Browse the repository at this point in the history
Fix/pvred model input
  • Loading branch information
kaseris committed Dec 14, 2023
2 parents 3dcfb33 + 59d5277 commit cdd1abe
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
8 changes: 4 additions & 4 deletions configs/pvred.yaml → configs/pvred_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ transforms:
- name: MinMaxScaleTransform
args:
feature_scale: [0.0, 1.0]
- name: CartToExpMapsTransform
args:
parents: null
# - name: CartToExpMapsTransform
# args:
# parents: null

loss:
name: MSELoss
Expand Down Expand Up @@ -48,7 +48,7 @@ model:
enc_type: lstm
dec_type: lstm
include_velocity: false
pos_enc: add
pos_enc: null
batch_first: true
std_thresh: 0.0001
use_std_mask: false
Expand Down
3 changes: 2 additions & 1 deletion src/skelcast/models/rnn/pvred.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def forward(self, x: torch.Tensor, y:torch.Tensor, masks: torch.Tensor = None) -

batch_size, seq_len, n_bodies, n_joints, dims = x.shape
x = x.view(batch_size, seq_len, n_bodies * n_joints * dims)
masks = masks.view(batch_size, seq_len, n_bodies * n_joints * dims)
if masks is not None:
masks = masks.view(batch_size, seq_len, n_bodies * n_joints * dims)
# Calculate the velocity if the include_velocity flag is true
if self.include_velocity:
vel_inp = self._calculate_velocity(x)
Expand Down

0 comments on commit cdd1abe

Please sign in to comment.