Skip to content

Commit

Permalink
Correct handling of the masks
Browse files Browse the repository at this point in the history
  • Loading branch information
kaseris committed Dec 12, 2023
1 parent 696e2c1 commit 64b9e52
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/skelcast/models/rnn/pvred.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import torch
import torch.nn as nn

Expand Down Expand Up @@ -174,7 +176,13 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64,
hidden_dim=dec_hidden_dim, batch_first=batch_first)


def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor:
def forward(self, x: torch.Tensor, y:torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor:
"""y is not used, it's only to satisfy the Runner's API
TODO: Remove y from the API, or find an adaptive way to infer the parameters"""

batch_size, seq_len, n_bodies, n_joints, dims = x.shape
x = x.view(batch_size, seq_len, n_bodies * n_joints * dims)
masks = masks.view(batch_size, seq_len, n_bodies * n_joints * dims)
# Calculate the velocity if the include_velocity flag is true
if self.include_velocity:
vel_inp = self._calculate_velocity(x)
Expand All @@ -197,11 +205,11 @@ def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor:
assert dec_out.shape == targets.shape, f'dec_out.shape must be equal to targets.shape, got {dec_out.shape} and {targets.shape}'
# Apply the padded length masks to the prediction
if self.use_padded_len_mask:
dec_out = dec_out * masks.float()
dec_out = dec_out * masks[:, self.observe_until:, :]

# Apply the std masks to the prediction
if self.use_std_mask:
dec_out = dec_out * mask_pred.float()
dec_out = dec_out * mask_pred.to(torch.float32)

# Calculate the loss
loss = self.loss_fn(dec_out, targets)
Expand All @@ -225,11 +233,11 @@ def _calculate_velocity(self, x: torch.Tensor) -> torch.Tensor:
velocity[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
return velocity

def training_step(self, x: torch.Tensor, y: torch.Tensor) -> dict:
def training_step(self, x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor) -> dict:
self.encoder.train()
self.decoder.train()
# Forward pass
dec_out, loss = self(x, y)
dec_out, loss = self(x, y, mask)
return {'loss': loss, 'out': dec_out}

@torch.no_grad()
Expand Down

0 comments on commit 64b9e52

Please sign in to comment.