Fix/pvred model input #50

Merged
merged 6 commits into from Dec 12, 2023
2 changes: 2 additions & 0 deletions src/skelcast/core/environment.py
@@ -115,6 +115,8 @@ def build_from_file(self, config_path: str) -> None:
checkpoint_dir=self.checkpoint_dir,
**cfgs.runner_config.get('args'))
logging.info(f'Finished building environment from {config_path}.')
self._runner.setup()
logging.info('Set up runner.')


def run(self) -> None:
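For context, runner initialization now happens inside `build_from_file` rather than at `run()` time. A minimal sketch of the resulting call order, assuming `env` is an already-constructed `Environment` and using a hypothetical config path:

```python
# Sketch of the call order after this change; `env` construction and the
# config path are placeholders, not part of this PR.
env.build_from_file('configs/pvred.yaml')  # now also invokes self._runner.setup()
env.run()                                  # the runner is already set up here
```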
9 changes: 5 additions & 4 deletions src/skelcast/data/dataset.py
@@ -147,7 +147,7 @@ def __call__(self, batch) -> NTURGBDSample:
batch_x = torch.nn.utils.rnn.pack_padded_sequence(batch_x, seq_lens, batch_first=True, enforce_sorted=False)
batch_y = torch.nn.utils.rnn.pack_padded_sequence(batch_y, seq_lens, batch_first=True, enforce_sorted=False)
labels = default_collate(labels)
return NTURGBDSample(x=batch_x, y=batch_y, label=labels)
return NTURGBDSample(x=batch_x, y=batch_y, label=labels, mask=None)

def get_windows(self, x):
seq_len = x.shape[0]
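Returning an explicit `mask=None` here implies that `NTURGBDSample` now carries a `mask` field. The real definition is not part of this diff; a sketch of the assumed structure:

```python
# Assumed shape of NTURGBDSample after this PR; the real definition lives elsewhere in the repo.
from typing import Any, NamedTuple, Optional

import torch

class NTURGBDSample(NamedTuple):
    x: torch.Tensor                      # padded or packed input sequences
    y: torch.Tensor                      # target sequences
    label: Any                           # collated labels (tensor or list)
    mask: Optional[torch.Tensor] = None  # None when the collator builds no mask
```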
@@ -190,20 +190,21 @@ def __call__(self, batch) -> NTURGBDSample:
seq_lens = [sample.shape[0] for sample, _ in batch]
labels = [label for _, label in batch]
pre_batch = []
pre_mask = []
for sample, _ in batch:
logging.debug(f'sample.shape: {sample.shape}')
if sample.shape[0] <= self.block_size:
# Sample the entire sequence
logging.debug(f'Detected a sample with sequence length {sample.shape[0]}')
pre_batch.append(sample)
pre_mask.append(torch.ones_like(sample))
else:
# Sample a random index
idx = torch.randint(low=0, high=sample.shape[0] - self.block_size, size=(1,)).item()
pre_batch.append(sample[idx:idx + self.block_size, ...])
pre_mask.append(torch.ones_like(sample[idx:idx + self.block_size, ...]))
# Pad the sequences to the maximum sequence length in the batch
batch_x = torch.nn.utils.rnn.pad_sequence(pre_batch, batch_first=True)
# Generate masks
masks = torch.nn.utils.rnn.pack_sequence([torch.ones(seq_len) for seq_len in seq_lens], enforce_sorted=False).to(torch.float32)
masks = torch.nn.utils.rnn.pad_sequence(pre_mask, batch_first=True)
return NTURGBDSample(x=batch_x, y=batch_x, label=labels, mask=masks)
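Note: the previous mask was built with `pack_sequence` over the full `seq_lens`, which yields a `PackedSequence` (not a padded tensor) and ignores the block-size truncation, so it could never line up with `batch_x`. Padding per-sample `ones_like` masks keeps mask and data element-aligned. A toy illustration with made-up shapes:

```python
# Toy illustration of the new masking scheme (shapes are made up).
import torch
from torch.nn.utils.rnn import pad_sequence

a = torch.randn(5, 3)  # sample of length 5
b = torch.randn(2, 3)  # sample of length 2
batch_x = pad_sequence([a, b], batch_first=True)  # shape (2, 5, 3)
masks = pad_sequence([torch.ones_like(a), torch.ones_like(b)], batch_first=True)
print(masks[1, :, 0])  # tensor([1., 1., 0., 0., 0.]) -- zeros mark padding
```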


17 changes: 9 additions & 8 deletions src/skelcast/experiments/runner.py
@@ -180,12 +180,13 @@ def fit(self):
return self._compile_results()

def training_step(self, train_batch: NTURGBDSample):
x, y = train_batch.x, train_batch.y
x, y, mask = train_batch.x, train_batch.y, train_batch.mask
# Cast to torch.float32 and move to the GPU
x, y = x.to(torch.float32), y.to(torch.float32)
x, y = x.to(self.device), y.to(self.device)
# TODO: Handle the mask None case
x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
self.model.train()
out = self.model.training_step(x, y)
out = self.model.training_step(x, y, mask) # TODO: Make the other models accept a mask as well
loss = out['loss']
self.optimizer.zero_grad()
loss.backward()
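The unconditional `mask.to(...)` above will raise `AttributeError` on the `mask=None` path from the windowed collator, which the TODO acknowledges. One possible guard, as a sketch only (not part of this PR, and assuming `x` is a dense tensor rather than a `PackedSequence`):

```python
# Sketch of a possible guard for the mask=None path (not part of this PR).
# Assumes x is a dense tensor; the packed-sequence path would need its own handling.
if mask is None:
    mask = torch.ones_like(x)  # treat every element as valid
```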
@@ -212,12 +213,12 @@ def training_step(self, train_batch: NTURGBDSample):
self.logger.add_scalar(tag='train/step_loss', scalar_value=loss.item(), global_step=len(self.training_loss_per_step))

def validation_step(self, val_batch: NTURGBDSample):
x, y = val_batch.x, val_batch.y
x, y, mask = val_batch.x, val_batch.y, val_batch.mask
# Cast to torch.float32 and move to the GPU
x, y = x.to(torch.float32), y.to(torch.float32)
x, y = x.to(self.device), y.to(self.device)
x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
self.model.eval()
out = self.model.validation_step(x, y)
out = self.model.validation_step(x, y, mask)
loss = out['loss']
self.validation_loss_per_step.append(loss.item())
# Log it to the logger
18 changes: 13 additions & 5 deletions src/skelcast/models/rnn/pvred.py
@@ -1,3 +1,5 @@
import logging

import torch
import torch.nn as nn

@@ -174,7 +176,13 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64,
hidden_dim=dec_hidden_dim, batch_first=batch_first)


def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor:
def forward(self, x: torch.Tensor, y: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor:
"""y is not used; it is present only to satisfy the Runner's API.
TODO: Remove y from the API, or find an adaptive way to infer the parameters."""

batch_size, seq_len, n_bodies, n_joints, dims = x.shape
x = x.view(batch_size, seq_len, n_bodies * n_joints * dims)
masks = masks.view(batch_size, seq_len, n_bodies * n_joints * dims)
# Calculate the velocity if the include_velocity flag is true
if self.include_velocity:
vel_inp = self._calculate_velocity(x)
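The flatten turns the skeleton structure into one feature vector per frame. For NTU RGB+D-style data (2 bodies, 25 joints, 3 coordinates, assumed here for illustration) that gives 150 features:

```python
# Illustrative shapes for the flatten step (NTU RGB+D-style dims assumed).
import torch

x = torch.randn(8, 64, 2, 25, 3)  # (batch, seq, bodies, joints, xyz)
batch_size, seq_len, n_bodies, n_joints, dims = x.shape
x = x.view(batch_size, seq_len, n_bodies * n_joints * dims)
print(x.shape)  # torch.Size([8, 64, 150])
```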
@@ -197,11 +205,11 @@ def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor:
assert dec_out.shape == targets.shape, f'dec_out.shape must be equal to targets.shape, got {dec_out.shape} and {targets.shape}'
# Apply the padded length masks to the prediction
if self.use_padded_len_mask:
dec_out = dec_out * masks.float()
dec_out = dec_out * masks[:, self.observe_until:, :]

# Apply the std masks to the prediction
if self.use_std_mask:
dec_out = dec_out * mask_pred.float()
dec_out = dec_out * mask_pred.to(torch.float32)

# Calculate the loss
loss = self.loss_fn(dec_out, targets)
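Note that zeroing `dec_out` at padded positions matches the zero-padded `targets`, so padded elements contribute zero error, but a mean-reduced loss still averages over them and deflates the loss for short sequences. A length-normalized masked MSE would be one alternative; this is a sketch, not what `loss_fn` does in this PR:

```python
# Sketch of a length-normalized masked MSE (an alternative, not this PR's loss_fn).
import torch

def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    sq_err = (pred - target) ** 2 * mask             # zero out padded positions
    return sq_err.sum() / mask.sum().clamp(min=1.0)  # average over valid elements only
```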
@@ -225,11 +233,11 @@ def _calculate_velocity(self, x: torch.Tensor) -> torch.Tensor:
velocity[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
return velocity
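`_calculate_velocity` is a first-order finite difference that leaves the first frame's velocity at zero; `torch.diff` reproduces the tail, as a quick check with a toy tensor:

```python
# Equivalence check for the velocity computation (toy tensor).
import torch

x = torch.randn(4, 10, 150)                      # (batch, seq, features)
velocity = torch.zeros_like(x)
velocity[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]  # as in _calculate_velocity
assert torch.allclose(velocity[:, 1:, :], torch.diff(x, dim=1))
```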

def training_step(self, x: torch.Tensor, y: torch.Tensor) -> dict:
def training_step(self, x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor) -> dict:
self.encoder.train()
self.decoder.train()
# Forward pass
dec_out, loss = self(x, y)
dec_out, loss = self(x, y, mask)
return {'loss': loss, 'out': dec_out}
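With the new signature, every caller must supply a mask. A hedged smoke test of the interface (model construction elided; `model` is assumed to be a constructed `PVRED` instance, and `y` mirrors `x` per the collator above):

```python
# Hedged smoke test of the new interface; `model` is assumed to be a
# constructed PVRED instance, and the input shapes are illustrative.
import torch

x = torch.randn(8, 64, 2, 25, 3)       # (batch, seq, bodies, joints, xyz)
mask = torch.ones_like(x)              # all elements valid in this toy case
out = model.training_step(x, x, mask)  # y mirrors x, as in the collator above
out['loss'].backward()
```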

@torch.no_grad()