
Commit

Merge pull request #50 from kaseris/fix/pvred-model-input
Fix/pvred model input
kaseris committed Dec 12, 2023
2 parents 4a48735 + 64b9e52 commit eb9f33a
Showing 4 changed files with 29 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/skelcast/core/environment.py
@@ -115,6 +115,8 @@ def build_from_file(self, config_path: str) -> None:
checkpoint_dir=self.checkpoint_dir,
**cfgs.runner_config.get('args'))
logging.info(f'Finished building environment from {config_path}.')
+ self._runner.setup()
+ logging.info(f'Set up runner.')


def run(self) -> None:
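With this change, build_from_file also initializes the runner, so a caller only needs to build the environment and then run it. A minimal usage sketch of that flow, assuming the class in environment.py is called Environment, takes no constructor arguments, and reads a YAML config (the class name and config path below are assumptions, not taken from this diff):

    from skelcast.core.environment import Environment

    # Build model, dataset and runner from a config file; build_from_file()
    # now also calls runner.setup() internally before training starts.
    env = Environment()                         # constructor signature assumed
    env.build_from_file('configs/pvred.yaml')   # config path is illustrative
    env.run()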
9 changes: 5 additions & 4 deletions src/skelcast/data/dataset.py
@@ -147,7 +147,7 @@ def __call__(self, batch) -> NTURGBDSample:
batch_x = torch.nn.utils.rnn.pack_padded_sequence(batch_x, seq_lens, batch_first=True, enforce_sorted=False)
batch_y = torch.nn.utils.rnn.pack_padded_sequence(batch_y, seq_lens, batch_first=True, enforce_sorted=False)
labels = default_collate(labels)
- return NTURGBDSample(x=batch_x, y=batch_y, label=labels)
+ return NTURGBDSample(x=batch_x, y=batch_y, label=labels, mask=None)

def get_windows(self, x):
seq_len = x.shape[0]
@@ -190,20 +190,21 @@ def __call__(self, batch) -> NTURGBDSample:
seq_lens = [sample.shape[0] for sample, _ in batch]
labels = [label for _, label in batch]
pre_batch = []
+ pre_mask = []
for sample, _ in batch:
logging.debug(f'sample.shape: {sample.shape}')
if sample.shape[0] <= self.block_size:
# Sample the entire sequence
logging.debug(f'Detected a sample with a sample length of {sample.shape[0]}')
pre_batch.append(sample)
+ pre_mask.append(torch.ones_like(sample))
else:
# Sample a random index
idx = torch.randint(low=0, high=sample.shape[0] - self.block_size, size=(1,)).item()
pre_batch.append(sample[idx:idx + self.block_size, ...])
+ pre_mask.append(torch.ones_like(sample[idx:idx + self.block_size, ...]))
# Pad the sequences to the maximum sequence length in the batch
batch_x = torch.nn.utils.rnn.pad_sequence(pre_batch, batch_first=True)
# Generate masks
- masks = torch.nn.utils.rnn.pack_sequence([torch.ones(seq_len) for seq_len in seq_lens], enforce_sorted=False).to(torch.float32)
+ masks = torch.nn.utils.rnn.pad_sequence(pre_mask, batch_first=True)
return NTURGBDSample(x=batch_x, y=batch_x, label=labels, mask=masks)


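The window collator now builds its padding mask from the same per-sample tensors that go into the batch, so pad_sequence guarantees the mask and the padded batch share a shape. A small self-contained sketch of that behaviour (sequence lengths and skeleton dimensions are illustrative):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Two variable-length "skeleton" sequences: (seq_len, n_bodies, n_joints, dims)
    a = torch.randn(5, 1, 25, 3)
    b = torch.randn(3, 1, 25, 3)

    pre_batch = [a, b]
    pre_mask = [torch.ones_like(a), torch.ones_like(b)]

    batch_x = pad_sequence(pre_batch, batch_first=True)  # (2, 5, 1, 25, 3), zero-padded
    masks = pad_sequence(pre_mask, batch_first=True)      # (2, 5, 1, 25, 3), 1s on real frames, 0s on padding

    assert batch_x.shape == masks.shape
    # The previous pack_sequence-based mask had a packed layout, which could not be
    # multiplied element-wise with the padded batch.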
17 changes: 9 additions & 8 deletions src/skelcast/experiments/runner.py
@@ -180,12 +180,13 @@ def fit(self):
return self._compile_results()

def training_step(self, train_batch: NTURGBDSample):
- x, y = train_batch.x, train_batch.y
+ x, y, mask = train_batch.x, train_batch.y, train_batch.mask
# Cast them to a torch float32 and move them to the gpu
- x, y = x.to(torch.float32), y.to(torch.float32)
- x, y = x.to(self.device), y.to(self.device)
+ # TODO: Handle the mask None case
+ x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
+ x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
self.model.train()
- out = self.model.training_step(x, y)
+ out = self.model.training_step(x, y, mask) # TODO: Make the other models accept a mask as well
loss = out['loss']
self.optimizer.zero_grad()
loss.backward()
@@ -212,12 +213,12 @@ def training_step(self, train_batch: NTURGBDSample):
self.logger.add_scalar(tag='train/step_loss', scalar_value=loss.item(), global_step=len(self.training_loss_per_step))

def validation_step(self, val_batch: NTURGBDSample):
- x, y = val_batch.x, val_batch.y
+ x, y, mask = val_batch.x, val_batch.y, val_batch.mask
# Cast them to a torch float32 and move them to the gpu
- x, y = x.to(torch.float32), y.to(torch.float32)
- x, y = x.to(self.device), y.to(self.device)
+ x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
+ x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
self.model.eval()
- out = self.model.validation_step(x, y)
+ out = self.model.validation_step(x, y, mask)
loss = out['loss']
self.validation_loss_per_step.append(loss.item())
# Log it to the logger
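The TODO in training_step notes that a batch may still arrive with mask=None, for example from the first collator above, which now returns mask=None. One possible guard before casting, kept as a sketch rather than part of this commit (the helper name _prepare_batch is hypothetical):

    import torch

    def _prepare_batch(x, y, mask, device):
        # Cast inputs to float32 and move them to the target device.
        x = x.to(torch.float32).to(device)
        y = y.to(torch.float32).to(device)
        if mask is None:
            # Fall back to an all-ones mask so downstream element-wise
            # multiplications become no-ops instead of failing on None.
            mask = torch.ones_like(x)
        mask = mask.to(torch.float32).to(device)
        return x, y, mask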
18 changes: 13 additions & 5 deletions src/skelcast/models/rnn/pvred.py
@@ -1,3 +1,5 @@
+ import logging
+
import torch
import torch.nn as nn

@@ -174,7 +176,13 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64,
hidden_dim=dec_hidden_dim, batch_first=batch_first)


- def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor:
+ def forward(self, x: torch.Tensor, y: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor:
+ """y is not used, it's only to satisfy the Runner's API
+ TODO: Remove y from the API, or find an adaptive way to infer the parameters"""
+
+ batch_size, seq_len, n_bodies, n_joints, dims = x.shape
+ x = x.view(batch_size, seq_len, n_bodies * n_joints * dims)
+ masks = masks.view(batch_size, seq_len, n_bodies * n_joints * dims)
# Calculate the velocity if the include_velocity flag is true
if self.include_velocity:
vel_inp = self._calculate_velocity(x)
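The two view calls above are the core of the fix: NTU-RGBD batches arrive as 5-D tensors, while the RNN encoder appears to expect a flat (batch, seq, features) input. A short shape walk-through with illustrative sizes (the concrete numbers are assumptions for the example only):

    import torch

    batch_size, seq_len, n_bodies, n_joints, dims = 2, 10, 1, 25, 3
    x = torch.randn(batch_size, seq_len, n_bodies, n_joints, dims)
    masks = torch.ones(batch_size, seq_len, n_bodies, n_joints, dims)

    # Flatten bodies, joints and coordinates into a single feature axis.
    x = x.view(batch_size, seq_len, n_bodies * n_joints * dims)          # (2, 10, 75)
    masks = masks.view(batch_size, seq_len, n_bodies * n_joints * dims)  # (2, 10, 75)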
@@ -197,11 +205,11 @@ def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor:
assert dec_out.shape == targets.shape, f'dec_out.shape must be equal to targets.shape, got {dec_out.shape} and {targets.shape}'
# Apply the padded length masks to the prediction
if self.use_padded_len_mask:
- dec_out = dec_out * masks.float()
+ dec_out = dec_out * masks[:, self.observe_until:, :]

# Apply the std masks to the prediction
if self.use_std_mask:
- dec_out = dec_out * mask_pred.float()
+ dec_out = dec_out * mask_pred.to(torch.float32)

# Calculate the loss
loss = self.loss_fn(dec_out, targets)
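The mask is sliced with observe_until because the decoder output seems to cover only the predicted suffix of the sequence, while the padding mask spans the full flattened input. A toy shape check, with every size made up for illustration:

    import torch

    seq_len, observe_until, feats = 10, 6, 75
    dec_out = torch.randn(2, seq_len - observe_until, feats)  # predictions for frames 6..9
    masks = torch.ones(2, seq_len, feats)                      # full-length padding mask

    dec_out = dec_out * masks[:, observe_until:, :]            # shapes align: (2, 4, 75)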
@@ -225,11 +233,11 @@ def _calculate_velocity(self, x: torch.Tensor) -> torch.Tensor:
velocity[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
return velocity

- def training_step(self, x: torch.Tensor, y: torch.Tensor) -> dict:
+ def training_step(self, x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor) -> dict:
self.encoder.train()
self.decoder.train()
# Forward pass
- dec_out, loss = self(x, y)
+ dec_out, loss = self(x, y, mask)
return {'loss': loss, 'out': dec_out}

@torch.no_grad()

