From 2b2013e6e99bbb8e2252290cf1f88ca297dd6033 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Tue, 12 Dec 2023 13:23:12 +0200 Subject: [PATCH 1/6] Use the mask from the batch sample --- src/skelcast/experiments/runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/skelcast/experiments/runner.py b/src/skelcast/experiments/runner.py index a70c1a9..9db2087 100644 --- a/src/skelcast/experiments/runner.py +++ b/src/skelcast/experiments/runner.py @@ -180,7 +180,7 @@ def fit(self): return self._compile_results() def training_step(self, train_batch: NTURGBDSample): - x, y = train_batch.x, train_batch.y + x, y, mask = train_batch.x, train_batch.y, train_batch.mask # Cast them to a torch float32 and move them to the gpu x, y = x.to(torch.float32), y.to(torch.float32) x, y = x.to(self.device), y.to(self.device) @@ -212,7 +212,7 @@ def training_step(self, train_batch: NTURGBDSample): self.logger.add_scalar(tag='train/step_loss', scalar_value=loss.item(), global_step=len(self.training_loss_per_step)) def validation_step(self, val_batch: NTURGBDSample): - x, y = val_batch.x, val_batch.y + x, y = val_batch.x, val_batch.y, val_batch.mask # Cast them to a torch float32 and move them to the gpu x, y = x.to(torch.float32), y.to(torch.float32) x, y = x.to(self.device), y.to(self.device) From dcea08f7233292b8567d50d381c2b8c7878e6ba4 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Tue, 12 Dec 2023 13:24:05 +0200 Subject: [PATCH 2/6] Initialize the default NTURGBDCollateFn returned sample mask with None --- src/skelcast/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/skelcast/data/dataset.py b/src/skelcast/data/dataset.py index a50f34b..a8e664d 100644 --- a/src/skelcast/data/dataset.py +++ b/src/skelcast/data/dataset.py @@ -147,7 +147,7 @@ def __call__(self, batch) -> NTURGBDSample: batch_x = torch.nn.utils.rnn.pack_padded_sequence(batch_x, seq_lens, batch_first=True, enforce_sorted=False) batch_y = torch.nn.utils.rnn.pack_padded_sequence(batch_y, seq_lens, batch_first=True, enforce_sorted=False) labels = default_collate(labels) - return NTURGBDSample(x=batch_x, y=batch_y, label=labels) + return NTURGBDSample(x=batch_x, y=batch_y, label=labels, mask=None) def get_windows(self, x): seq_len = x.shape[0] From eb52ea301a228098c872582aa5a862dc5b8acc74 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Tue, 12 Dec 2023 13:37:05 +0200 Subject: [PATCH 3/6] Setup runner --- src/skelcast/core/environment.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/skelcast/core/environment.py b/src/skelcast/core/environment.py index 3880a08..1f08fe6 100644 --- a/src/skelcast/core/environment.py +++ b/src/skelcast/core/environment.py @@ -115,6 +115,8 @@ def build_from_file(self, config_path: str) -> None: checkpoint_dir=self.checkpoint_dir, **cfgs.runner_config.get('args')) logging.info(f'Finished building environment from {config_path}.') + self._runner.setup() + logging.info(f'Set up runner.') def run(self) -> None: From 479a7f4423d2830d5776731d17d037924f44b0d6 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Tue, 12 Dec 2023 13:58:50 +0200 Subject: [PATCH 4/6] Fix the predictions mask creation --- src/skelcast/data/dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/skelcast/data/dataset.py b/src/skelcast/data/dataset.py index a8e664d..452cf25 100644 --- a/src/skelcast/data/dataset.py +++ b/src/skelcast/data/dataset.py @@ -190,20 +190,21 @@ def __call__(self, batch) -> NTURGBDSample: seq_lens = 
[sample.shape[0] for sample, _ in batch] labels = [label for _, label in batch] pre_batch = [] + pre_mask = [] for sample, _ in batch: - logging.debug(f'sample.shape: {sample.shape}') if sample.shape[0] <= self.block_size: # Sample the entire sequence - logging.debug(f'Detected a sample with a sample length of {sample.shape[0]}') pre_batch.append(sample) + pre_mask.append(torch.ones_like(sample)) else: # Sample a random index idx = torch.randint(low=0, high=sample.shape[0] - self.block_size, size=(1,)).item() pre_batch.append(sample[idx:idx + self.block_size, ...]) + pre_mask.append(torch.ones_like(sample[idx:idx + self.block_size, ...])) # Pad the sequences to the maximum sequence length in the batch batch_x = torch.nn.utils.rnn.pad_sequence(pre_batch, batch_first=True) # Generate masks - masks = torch.nn.utils.rnn.pack_sequence([torch.ones(seq_len) for seq_len in seq_lens], enforce_sorted=False).to(torch.float32) + masks = torch.nn.utils.rnn.pad_sequence(pre_mask, batch_first=True) return NTURGBDSample(x=batch_x, y=batch_x, label=labels, mask=masks) From 696e2c16e8c51563bdbd6ef32d9a6b8952279a13 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Tue, 12 Dec 2023 14:00:02 +0200 Subject: [PATCH 5/6] Include masks to the training --- src/skelcast/experiments/runner.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/skelcast/experiments/runner.py b/src/skelcast/experiments/runner.py index 9db2087..7a00c6d 100644 --- a/src/skelcast/experiments/runner.py +++ b/src/skelcast/experiments/runner.py @@ -182,10 +182,11 @@ def fit(self): def training_step(self, train_batch: NTURGBDSample): x, y, mask = train_batch.x, train_batch.y, train_batch.mask # Cast them to a torch float32 and move them to the gpu - x, y = x.to(torch.float32), y.to(torch.float32) - x, y = x.to(self.device), y.to(self.device) + # TODO: Handle the mask None case + x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32) + x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device) self.model.train() - out = self.model.training_step(x, y) + out = self.model.training_step(x, y, mask) # TODO: Make the other models accept a mask as well loss = out['loss'] self.optimizer.zero_grad() loss.backward() @@ -212,12 +213,12 @@ def training_step(self, train_batch: NTURGBDSample): self.logger.add_scalar(tag='train/step_loss', scalar_value=loss.item(), global_step=len(self.training_loss_per_step)) def validation_step(self, val_batch: NTURGBDSample): - x, y = val_batch.x, val_batch.y, val_batch.mask + x, y, mask = val_batch.x, val_batch.y, val_batch.mask # Cast them to a torch float32 and move them to the gpu - x, y = x.to(torch.float32), y.to(torch.float32) - x, y = x.to(self.device), y.to(self.device) + x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32) + x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device) self.model.eval() - out = self.model.validation_step(x, y) + out = self.model.validation_step(x, y, mask) loss = out['loss'] self.validation_loss_per_step.append(loss.item()) # Log it to the logger From 64b9e5272c747b547ecbcc8b4db47ab5f2d33b3b Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Tue, 12 Dec 2023 14:00:29 +0200 Subject: [PATCH 6/6] Correct handling of the masks --- src/skelcast/models/rnn/pvred.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/skelcast/models/rnn/pvred.py b/src/skelcast/models/rnn/pvred.py index dd2184a..1ae1bd1 100644 --- 
a/src/skelcast/models/rnn/pvred.py +++ b/src/skelcast/models/rnn/pvred.py @@ -1,3 +1,5 @@ +import logging + import torch import torch.nn as nn @@ -174,7 +176,13 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64, hidden_dim=dec_hidden_dim, batch_first=batch_first) - def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, y:torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor: + """y is not used, it's only to satisfy the Runner's API + TODO: Remove y from the API, or find an adaptive way to infer the parameters""" + + batch_size, seq_len, n_bodies, n_joints, dims = x.shape + x = x.view(batch_size, seq_len, n_bodies * n_joints * dims) + masks = masks.view(batch_size, seq_len, n_bodies * n_joints * dims) # Calculate the velocity if the include_velocity flag is true if self.include_velocity: vel_inp = self._calculate_velocity(x) @@ -197,11 +205,11 @@ def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor: assert dec_out.shape == targets.shape, f'dec_out.shape must be equal to targets.shape, got {dec_out.shape} and {targets.shape}' # Apply the padded length masks to the prediction if self.use_padded_len_mask: - dec_out = dec_out * masks.float() + dec_out = dec_out * masks[:, self.observe_until:, :] # Apply the std masks to the prediction if self.use_std_mask: - dec_out = dec_out * mask_pred.float() + dec_out = dec_out * mask_pred.to(torch.float32) # Calculate the loss loss = self.loss_fn(dec_out, targets) @@ -225,11 +233,11 @@ def _calculate_velocity(self, x: torch.Tensor) -> torch.Tensor: velocity[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :] return velocity - def training_step(self, x: torch.Tensor, y: torch.Tensor) -> dict: + def training_step(self, x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor) -> dict: self.encoder.train() self.decoder.train() # Forward pass - dec_out, loss = self(x, y) + dec_out, loss = self(x, y, mask) return {'loss': loss, 'out': dec_out} @torch.no_grad()
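The collate-side change in PATCH 4/6 boils down to three steps: crop each sequence to at most block_size frames, pad the crops to the longest crop in the batch, and build a mask of exactly the same shape so padded positions can be zeroed later. Below is a minimal, self-contained sketch of that idea against plain PyTorch; collate_with_mask is an illustrative stand-in, not the repository's NTURGBDCollateFn, and the (frames, 25 joints, 3 dims) sample shape in the usage lines is an assumption.

import torch

def collate_with_mask(batch, block_size):
    # Illustrative stand-in for the PATCH 4/6 logic: crop, pad, mask.
    pre_batch, pre_mask = [], []
    for sample in batch:  # each sample: (seq_len, ...) float tensor
        if sample.shape[0] <= block_size:
            block = sample  # short sequence: keep it whole
        else:
            # pick a random temporal window of block_size frames
            idx = torch.randint(0, sample.shape[0] - block_size, (1,)).item()
            block = sample[idx:idx + block_size]
        pre_batch.append(block)
        # the mask mirrors the block's shape: 1.0 marks a real frame
        pre_mask.append(torch.ones_like(block))
    # pad data and mask identically, so zeros mark the padded frames
    x = torch.nn.utils.rnn.pad_sequence(pre_batch, batch_first=True)
    mask = torch.nn.utils.rnn.pad_sequence(pre_mask, batch_first=True)
    return x, mask

# Usage: two sequences of different lengths, block_size of 4
a, b = torch.randn(3, 25, 3), torch.randn(6, 25, 3)
x, mask = collate_with_mask([a, b], block_size=4)
print(x.shape, mask.shape)  # torch.Size([2, 4, 25, 3]) for both
print(mask[0, :, 0, 0])     # tensor([1., 1., 1., 0.]) -- one padded frame

Building the mask with torch.ones_like on the cropped block, as the patch does, rather than packing per-sequence length vectors, is what keeps the mask the same shape as the padded batch and therefore directly multipliable against model outputs.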
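PATCH 5/6 leaves "TODO: Handle the mask None case": after PATCH 2/6 the default NTURGBDCollateFn returns mask=None, so the unconditional mask.to(torch.float32) in training_step would raise AttributeError on such batches. The sketch below shows one possible shape for that guard; it is not the project's code, and model.training_step tolerating mask=None is an assumption, not something this series implements.

import torch

def mask_aware_training_step(model, optimizer, batch, device):
    # Sketch of the guard hinted at by the TODO in PATCH 5/6.
    x = batch.x.to(torch.float32).to(device)
    y = batch.y.to(torch.float32).to(device)
    mask = batch.mask
    if mask is not None:
        # only cast/move the mask when the collate function produced one
        mask = mask.to(torch.float32).to(device)
    model.train()
    out = model.training_step(x, y, mask)  # assumed to accept mask=None
    loss = out['loss']
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()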
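PATCH 6/6 flattens the (batch, seq_len, n_bodies, n_joints, dims) input so the collate-time mask lines up element-for-element with the decoder output, then multiplies the prediction by masks[:, observe_until:, :] before computing the loss. The sketch below isolates that masking step with a plain MSE loss. One deliberate deviation, flagged here: it masks the targets as well, so padded frames contribute exactly zero to the loss; the diff itself multiplies only dec_out.

import torch
import torch.nn.functional as F

def masked_prediction_loss(dec_out, targets, masks, observe_until):
    # Zero padded positions of the predicted frames before the loss;
    # the mask is sliced to the same horizon the decoder predicted.
    horizon_mask = masks[:, observe_until:, :]
    return F.mse_loss(dec_out * horizon_mask, targets * horizon_mask)

# Shapes as in the diff: batch=2, seq_len=10, 2 bodies, 25 joints, 3 dims
x = torch.randn(2, 10, 2, 25, 3)
mask = torch.ones_like(x)
mask[1, 7:] = 0.0                 # pretend sample 1 has 3 padded frames
x = x.view(2, 10, -1)             # flatten, as forward() now does
mask = mask.view(2, 10, -1)
observe_until = 6
loss = masked_prediction_loss(x[:, observe_until:], x[:, observe_until:],
                              mask, observe_until)
print(loss)  # tensor(0.) -- padded frames contribute nothing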