From 1c3ccd0758afa479d9df67286c0570f8c594f7fe Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 7 Dec 2023 12:01:03 +0200 Subject: [PATCH 1/9] Dim fix for positional encoding --- src/skelcast/models/transformers/base.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/skelcast/models/transformers/base.py b/src/skelcast/models/transformers/base.py index 1a38f46..6315ecc 100644 --- a/src/skelcast/models/transformers/base.py +++ b/src/skelcast/models/transformers/base.py @@ -65,11 +65,22 @@ def __init__(self, d_model, max_len=5000): # Create a long enough 'PE' matrix with position and dimension indexes pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + + # Adjust div_term calculation div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) - + + # Assign sine to even and cosine to odd indices pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) + # Adjust cosine assignment for odd d_model + if d_model % 2 == 0: + pe[:, 1::2] = torch.cos(position * div_term) + else: + # For odd d_model, the last cosine term needs to be calculated differently + pe[:, 1::2] = torch.cos(position * div_term[:-1]) + last_cos = torch.cos(position * div_term[-1]) + pe[:, -1] = last_cos.squeeze() + pe = pe.unsqueeze(0).transpose(0, 1) # Registers pe as a buffer that should not be considered a model parameter. self.register_buffer('pe', pe) From 034a53cafded09ecd429092392c98b0c00e3ef54 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 7 Dec 2023 12:01:32 +0200 Subject: [PATCH 2/9] New model registration --- src/skelcast/models/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/skelcast/models/__init__.py b/src/skelcast/models/__init__.py index a1bdf3a..25ab7af 100644 --- a/src/skelcast/models/__init__.py +++ b/src/skelcast/models/__init__.py @@ -4,4 +4,5 @@ MODELS = Registry() from .rnn.lstm import SimpleLSTMRegressor -from .transformers.transformer import ForecastTransformer \ No newline at end of file +from .transformers.transformer import ForecastTransformer +from .rnn.pvred import PositionalVelocityRecurrentEncoderDecoder \ No newline at end of file From d213859f674b44b40f9fe98ed729a00cb3c2e4b8 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 7 Dec 2023 14:11:50 +0200 Subject: [PATCH 3/9] PVRED Implementation --- src/skelcast/models/rnn/pvred.py | 240 +++++++++++++++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 src/skelcast/models/rnn/pvred.py diff --git a/src/skelcast/models/rnn/pvred.py b/src/skelcast/models/rnn/pvred.py new file mode 100644 index 0000000..2f38a2e --- /dev/null +++ b/src/skelcast/models/rnn/pvred.py @@ -0,0 +1,240 @@ +import torch +import torch.nn as nn + +from skelcast.models import MODELS +from skelcast.models.module import SkelcastModule +from skelcast.models.transformers.base import PositionalEncoding + + +class Encoder(nn.Module): + def __init__(self, rnn_type: str = 'rnn', + input_dim: int = 75, + hidden_dim: int = 256, + batch_first: bool = True, + dropout: float = 0.2, + use_residual: bool = True) -> None: + super().__init__() + assert rnn_type in ['lstm', 'gru'], f'rnn_type must be one of rnn, lstm, gru, got {rnn_type}' + self.rnn_type = rnn_type + self.batch_first = batch_first + self.use_residual = use_residual + + if self.rnn_type == 'lstm': + self.rnn = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=batch_first) + elif self.rnn_type == 'gru': + self.rnn = nn.GRU(input_size=input_dim, hidden_size=hidden_dim, batch_first=batch_first) + self.linear = nn.Linear(hidden_dim, input_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out, hidden = self.rnn(x) + out = self.dropout(out) + out = self.linear(out) + if self.use_residual: + out = out + x + return out, hidden + + +class Decoder(nn.Module): + def __init__(self,rnn_type: str = 'rnn', + input_dim: int = 75, + hidden_dim: int = 256, + batch_first: bool = True, + dropout: float = 0.2, + use_residual: bool = True) -> None: + super().__init__() + assert rnn_type in ['lstm', 'gru'], f'rnn_type must be one of rnn, lstm, gru, got {rnn_type}' + self.rnn_type = rnn_type + self.batch_first = batch_first + self.use_residual = use_residual + + if self.rnn_type == 'lstm': + self.rnn = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=batch_first) + elif self.rnn_type == 'gru': + self.rnn = nn.GRU(input_size=input_dim, hidden_size=hidden_dim, batch_first=batch_first) + self.linear = nn.Linear(hidden_dim, input_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, hidden: torch.Tensor = None) -> torch.Tensor: + out, _ = self.rnn(x, hidden) + out = self.dropout(out) + out = self.linear(out) + if self.use_residual: + out = out + x + return out + + +@MODELS.register_module() +class PositionalVelocityRecurrentEncoderDecoder(SkelcastModule): + """ + Positional Velocity Recurrent Encoder Decoder (PVRED) model. + The model was proposed in the paper: https://arxiv.org/abs/1906.06514 + + The model consists of an encoder and a decoder. The encoder is a recurrent neural network (RNN) + that encodes the input sequence into a latent representation. The decoder is also an RNN that + decodes the latent representation into a sequence of the same length as the input sequence. + + The authors of PVRED introduce the idea of positional encoding. Below, we describe how can the + positional encoding be applied to the input sequence. + + Three ways of taking into account the positional encoding: + 1. Concatenate the positional encoding to the input + 2. Add the positional encoding to the input + 3. NOP (no positional encoding) - if pos_enc is None + + Args: + --- + - `input_dim` (`int`): Input dimension of the model + - `enc_hidden_dim` (`int`): Hidden dimension of the encoder + - `dec_hidden_dim` (`int`): Hidden dimension of the decoder + - `enc_type` (`str`): Type of the encoder, one of lstm, gru + - `dec_type` (`str`): Type of the decoder, one of lstm, gru + - `include_velocity` (`bool`): Flag to indicate whether to include the velocity of the input + - `pos_enc` (`str`): Type of the positional encoding, one of concat, add, None + - `loss_fn` (`torch.nn.Module`): Loss function to be used for training + - `batch_first` (`bool`): Flag to indicate whether the batch dimension is the first dimension + - `std_thresh` (`float`): Threshold for the standard deviation of the input + - `use_std_mask` (`bool`): Flag to indicate whether to use the standard deviation mask + - `use_padded_len_mask` (`bool`): Flag to indicate whether to use the padded length mask + + Returns: + --- + - `dec_out` (`torch.Tensor`): Output of the decoder + - `loss` (`torch.Tensor`): Loss of the model + + Examples: + --- + ```python + import torch + from skelcast.models.rnn.pvred import PositionalVelocityRecurrentEncoderDecoder + + model = PositionalVelocityRecurrentEncoderDecoder(input_dim = 75) + x = torch.randn(32, 100, 75) + y = torch.randn(32, 100, 75) + dec_out, loss = model(x, y) + ``` + """ + def __init__(self, input_dim: int, enc_hidden_dim: int = 64, + dec_hidden_dim: int = 64, + enc_type: str = 'lstm', + dec_type: str = 'lstm', + include_velocity: bool = False, + pos_enc: str = None, + loss_fn: nn.Module = None, + batch_first: bool = True, + std_thresh: float = 1e-4, + use_std_mask: bool = False, + use_padded_len_mask: bool = False) -> None: + assert pos_enc in ['concat', 'add', None], f'pos_enc must be one of concat, add, None, got {pos_enc}' + assert isinstance(loss_fn, nn.Module), f'loss_fn must be an instance of torch.nn.Module, got {type(loss_fn)}' + assert enc_type in ['lstm', 'gru'], f'enc_type must be one of lstm, gru, got {enc_type}' + assert dec_type in ['lstm', 'gru'], f'dec_type must be one of lstm, gru, got {dec_type}' + super().__init__() + self.input_dim = input_dim + self.enc_hidden_dim = enc_hidden_dim + self.dec_hidden_dim = dec_hidden_dim + self.enc_type = enc_type + self.batch_first = batch_first + self.std_thresh = std_thresh + self.use_std_mask = use_std_mask + self.use_padded_len_mask = use_padded_len_mask + + self.include_velocity = include_velocity + if self.include_velocity: + # Double the input dimension because we are concatenating the xyz velocities to the input + self.input_dim = self.input_dim * 2 + + self.pos_enc_method = pos_enc + if pos_enc == 'concat' or pos_enc == 'add': + self.pos_enc = PositionalEncoding(input_dim) + if pos_enc == 'concat': + self.input_dim = self.input_dim + input_dim + else: + self.pos_enc = None + + self.loss_fn = loss_fn + + # Build encoder + self.encoder = Encoder(rnn_type=enc_type, input_dim=input_dim, + hidden_dim=enc_hidden_dim, batch_first=batch_first) + # Build decoder + self.decoder = Decoder(rnn_type=dec_type, input_dim=input_dim, + hidden_dim=dec_hidden_dim, batch_first=batch_first) + + + def forward(self, x: torch.Tensor, y: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor: + mask_pred = torch.std(x, dim=1) > self.std_thresh if self.batch_first else torch.std(x, dim=0) > self.std_thresh + # Calculate the velocity if the include_velocity flag is true + if self.include_velocity: + vel_inp = self._calculate_velocity(x) + vel_target = self._calculate_velocity(y) + # Concatenate the velocity to the input and the targets + x = torch.cat([x, vel_inp], dim=-1) + y = torch.cat([y, vel_target], dim=-1) + + # If the pos_enc is not None, apply the positional encoding, dependent on the pos_enc_method + + if self.pos_enc is not None: + if self.pos_enc_method == 'concat': + pass # TODO: Implement the concatenation of the positional encoding + elif self.pos_enc_method == 'add': + x += self.pos_enc.pe.repeat(1, x.shape[0], 1).permute(1, 0, 2)[:, :x.shape[1], :] + y += self.pos_enc.pe.repeat(1, y.shape[0], 1).permute(1, 0, 2)[:, :y.shape[1], :] + + # Encode the input + enc_out, enc_hidden = self.encoder(x) + # Decode the output + dec_out = self.decoder(enc_out, enc_hidden) + + # Calculate the loss + loss = self.loss_fn(dec_out, y) + + # Mask the loss + if self.use_std_mask: + loss = loss * mask_pred.float() + + # We mask the loss with the masks tensor if the use_padded_len_mask flag is true + # in order to suppress the loss contribution of the padded values + if self.use_padded_len_mask: + loss = loss * masks.float() + + return dec_out, loss + + def _calculate_velocity(self, x: torch.Tensor) -> torch.Tensor: + """ + Calculate the velocity of the input tensor + + Args: + --- + + - `x` (`torch.Tensor`): Input tensor of shape `(batch_size, seq_len, input_dim)` + + Returns: + --- + + - Velocity tensor of shape (batch_size, seq_len, input_dim) + """ + # Calculate the velocity + velocity = torch.zeros_like(x) + velocity[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :] + return velocity + + def training_step(self, x: torch.Tensor, y: torch.Tensor) -> dict: + self.encoder.train() + self.decoder.train() + # Forward pass + dec_out, loss = self(x, y) + return {'loss': loss, 'out': dec_out} + + @torch.no_grad() + def validation_step(self, *args, **kwargs) -> dict: + self.encoder.eval() + self.decoder.eval() + return self.training_step(*args, **kwargs) + + + @torch.no_grad() + def predict(self, *args, **kwargs): + self.encoder.eval() + self.decoder.eval() + \ No newline at end of file From 0b3796d0878e6b2c6b776792358626a6775cc6ad Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 7 Dec 2023 16:37:15 +0200 Subject: [PATCH 4/9] PVRed model batched input support --- src/skelcast/models/rnn/pvred.py | 69 ++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/src/skelcast/models/rnn/pvred.py b/src/skelcast/models/rnn/pvred.py index 2f38a2e..b9ae283 100644 --- a/src/skelcast/models/rnn/pvred.py +++ b/src/skelcast/models/rnn/pvred.py @@ -1,11 +1,12 @@ import torch import torch.nn as nn -from skelcast.models import MODELS +from skelcast.models import MODELS, ENCODERS, DECODERS from skelcast.models.module import SkelcastModule from skelcast.models.transformers.base import PositionalEncoding +@ENCODERS.register_module() class Encoder(nn.Module): def __init__(self, rnn_type: str = 'rnn', input_dim: int = 75, @@ -35,6 +36,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out, hidden +@DECODERS.register_module() class Decoder(nn.Module): def __init__(self,rnn_type: str = 'rnn', input_dim: int = 75, @@ -55,13 +57,17 @@ def __init__(self,rnn_type: str = 'rnn', self.linear = nn.Linear(hidden_dim, input_dim) self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor, hidden: torch.Tensor = None) -> torch.Tensor: - out, _ = self.rnn(x, hidden) - out = self.dropout(out) - out = self.linear(out) - if self.use_residual: - out = out + x - return out + def forward(self, x: torch.Tensor, hidden: torch.Tensor = None, timesteps_to_predict: int = 5) -> torch.Tensor: + predictions = [] + for _ in range(timesteps_to_predict): + out, hidden = self.rnn(x, hidden) + out = self.dropout(out) + out = self.linear(out) + if self.use_residual: + out = out + x + predictions.append(out) + x = out + return torch.cat(predictions, dim=1) @MODELS.register_module() @@ -96,6 +102,7 @@ class PositionalVelocityRecurrentEncoderDecoder(SkelcastModule): - `std_thresh` (`float`): Threshold for the standard deviation of the input - `use_std_mask` (`bool`): Flag to indicate whether to use the standard deviation mask - `use_padded_len_mask` (`bool`): Flag to indicate whether to use the padded length mask + - `observe_until` (`int`): The number of frames to observe before predicting the future Returns: --- @@ -124,7 +131,8 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64, batch_first: bool = True, std_thresh: float = 1e-4, use_std_mask: bool = False, - use_padded_len_mask: bool = False) -> None: + use_padded_len_mask: bool = False, + observe_until: int = 30) -> None: assert pos_enc in ['concat', 'add', None], f'pos_enc must be one of concat, add, None, got {pos_enc}' assert isinstance(loss_fn, nn.Module), f'loss_fn must be an instance of torch.nn.Module, got {type(loss_fn)}' assert enc_type in ['lstm', 'gru'], f'enc_type must be one of lstm, gru, got {enc_type}' @@ -138,6 +146,7 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64, self.std_thresh = std_thresh self.use_std_mask = use_std_mask self.use_padded_len_mask = use_padded_len_mask + self.observe_until = observe_until self.include_velocity = include_velocity if self.include_velocity: @@ -155,6 +164,7 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64, self.loss_fn = loss_fn # Build encoder + # TODO: Convert them to registry-type build self.encoder = Encoder(rnn_type=enc_type, input_dim=input_dim, hidden_dim=enc_hidden_dim, batch_first=batch_first) # Build decoder @@ -162,42 +172,41 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64, hidden_dim=dec_hidden_dim, batch_first=batch_first) - def forward(self, x: torch.Tensor, y: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor: - mask_pred = torch.std(x, dim=1) > self.std_thresh if self.batch_first else torch.std(x, dim=0) > self.std_thresh + def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor: # Calculate the velocity if the include_velocity flag is true if self.include_velocity: vel_inp = self._calculate_velocity(x) - vel_target = self._calculate_velocity(y) # Concatenate the velocity to the input and the targets x = torch.cat([x, vel_inp], dim=-1) - y = torch.cat([y, vel_target], dim=-1) # If the pos_enc is not None, apply the positional encoding, dependent on the pos_enc_method if self.pos_enc is not None: if self.pos_enc_method == 'concat': - pass # TODO: Implement the concatenation of the positional encoding + raise NotImplementedError('Concat positional encoding is not implemented yet') elif self.pos_enc_method == 'add': - x += self.pos_enc.pe.repeat(1, x.shape[0], 1).permute(1, 0, 2)[:, :x.shape[1], :] - y += self.pos_enc.pe.repeat(1, y.shape[0], 1).permute(1, 0, 2)[:, :y.shape[1], :] - + x += self.pos_enc.pe.repeat(1, x.shape[0], 1).permute(1, 0, 2) + + encoder_input, decoder_initial_value, targets = x[:, :self.observe_until, :], x[:, self.observe_until, :], x[:, self.observe_until:, :] + mask_pred = torch.std(x, dim=1) > self.std_thresh if self.batch_first else torch.std(x, dim=0) > self.std_thresh # Encode the input - enc_out, enc_hidden = self.encoder(x) + enc_out, enc_hidden = self.encoder(encoder_input) # Decode the output - dec_out = self.decoder(enc_out, enc_hidden) + dec_out = self.decoder(decoder_initial_value.unsqueeze(1), enc_hidden, timesteps_to_predict=targets.shape[1]) + + # The decoder's output should have a shape of (batch_size, seq_len, input_dim) + assert dec_out.shape == targets.shape, f'dec_out.shape must be equal to targets.shape, got {dec_out.shape} and {targets.shape}' + # Apply the padded length masks to the prediction + if self.use_padded_len_mask: + dec_out = dec_out * masks.float() + + # Apply the std masks to the prediction + if self.use_std_mask: + dec_out = dec_out * mask_pred.float() # Calculate the loss - loss = self.loss_fn(dec_out, y) + loss = self.loss_fn(dec_out, targets) - # Mask the loss - if self.use_std_mask: - loss = loss * mask_pred.float() - - # We mask the loss with the masks tensor if the use_padded_len_mask flag is true - # in order to suppress the loss contribution of the padded values - if self.use_padded_len_mask: - loss = loss * masks.float() - return dec_out, loss def _calculate_velocity(self, x: torch.Tensor) -> torch.Tensor: @@ -206,12 +215,10 @@ def _calculate_velocity(self, x: torch.Tensor) -> torch.Tensor: Args: --- - - `x` (`torch.Tensor`): Input tensor of shape `(batch_size, seq_len, input_dim)` Returns: --- - - Velocity tensor of shape (batch_size, seq_len, input_dim) """ # Calculate the velocity From 4993bfb502245ee09a333ac32c8c3553f06edc49 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 7 Dec 2023 16:37:50 +0200 Subject: [PATCH 5/9] Register encoders and decoders. --- src/skelcast/models/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/skelcast/models/__init__.py b/src/skelcast/models/__init__.py index 25ab7af..757bcf8 100644 --- a/src/skelcast/models/__init__.py +++ b/src/skelcast/models/__init__.py @@ -2,7 +2,10 @@ from skelcast.core.registry import Registry MODELS = Registry() +ENCODERS = Registry() +DECODERS = Registry() from .rnn.lstm import SimpleLSTMRegressor from .transformers.transformer import ForecastTransformer -from .rnn.pvred import PositionalVelocityRecurrentEncoderDecoder \ No newline at end of file +from .rnn.pvred import PositionalVelocityRecurrentEncoderDecoder +from .rnn.pvred import Encoder, Decoder \ No newline at end of file From 994fbfabe5f727084ea4d00eb692b9d04e7d5a55 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 7 Dec 2023 16:43:53 +0200 Subject: [PATCH 6/9] Update docstring example --- src/skelcast/models/rnn/pvred.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/skelcast/models/rnn/pvred.py b/src/skelcast/models/rnn/pvred.py index b9ae283..eca7d03 100644 --- a/src/skelcast/models/rnn/pvred.py +++ b/src/skelcast/models/rnn/pvred.py @@ -115,10 +115,11 @@ class PositionalVelocityRecurrentEncoderDecoder(SkelcastModule): import torch from skelcast.models.rnn.pvred import PositionalVelocityRecurrentEncoderDecoder - model = PositionalVelocityRecurrentEncoderDecoder(input_dim = 75) + model = PositionalVelocityRecurrentEncoderDecoder(input_dim = 75, observe_until=80) x = torch.randn(32, 100, 75) - y = torch.randn(32, 100, 75) - dec_out, loss = model(x, y) + dec_out, loss = model(x) + dec_out.shape + >>> torch.Size([32, 20, 75]) ``` """ def __init__(self, input_dim: int, enc_hidden_dim: int = 64, From fe4fb85011b9e099d6993baef4270b79fb2bcf7a Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Fri, 8 Dec 2023 12:08:43 +0200 Subject: [PATCH 7/9] Concat/add positional embedding support. --- src/skelcast/models/rnn/pvred.py | 14 ++++++-------- src/skelcast/models/transformers/base.py | 16 +++++++++++----- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/skelcast/models/rnn/pvred.py b/src/skelcast/models/rnn/pvred.py index eca7d03..9251bbb 100644 --- a/src/skelcast/models/rnn/pvred.py +++ b/src/skelcast/models/rnn/pvred.py @@ -155,10 +155,11 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64, self.input_dim = self.input_dim * 2 self.pos_enc_method = pos_enc - if pos_enc == 'concat' or pos_enc == 'add': - self.pos_enc = PositionalEncoding(input_dim) - if pos_enc == 'concat': - self.input_dim = self.input_dim + input_dim + if self.pos_enc_method == 'concat': + self.pos_enc = PositionalEncoding(input_dim, mode='concat') + self.input_dim = self.input_dim + input_dim + elif self.pos_enc_method == 'add': + self.pos_enc = PositionalEncoding(input_dim, mode='add') else: self.pos_enc = None @@ -183,10 +184,7 @@ def forward(self, x: torch.Tensor, masks: torch.Tensor = None) -> torch.Tensor: # If the pos_enc is not None, apply the positional encoding, dependent on the pos_enc_method if self.pos_enc is not None: - if self.pos_enc_method == 'concat': - raise NotImplementedError('Concat positional encoding is not implemented yet') - elif self.pos_enc_method == 'add': - x += self.pos_enc.pe.repeat(1, x.shape[0], 1).permute(1, 0, 2) + x = self.pos_enc(x) encoder_input, decoder_initial_value, targets = x[:, :self.observe_until, :], x[:, self.observe_until, :], x[:, self.observe_until:, :] mask_pred = torch.std(x, dim=1) > self.std_thresh if self.batch_first else torch.std(x, dim=0) > self.std_thresh diff --git a/src/skelcast/models/transformers/base.py b/src/skelcast/models/transformers/base.py index 6315ecc..4846804 100644 --- a/src/skelcast/models/transformers/base.py +++ b/src/skelcast/models/transformers/base.py @@ -60,8 +60,10 @@ def forward(self, x): class PositionalEncoding(nn.Module): - def __init__(self, d_model, max_len=5000): + def __init__(self, d_model, max_len=5000, mode='add'): super(PositionalEncoding, self).__init__() + self.mode = mode + # Create a long enough 'PE' matrix with position and dimension indexes pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) @@ -76,18 +78,22 @@ def __init__(self, d_model, max_len=5000): if d_model % 2 == 0: pe[:, 1::2] = torch.cos(position * div_term) else: - # For odd d_model, the last cosine term needs to be calculated differently pe[:, 1::2] = torch.cos(position * div_term[:-1]) last_cos = torch.cos(position * div_term[-1]) pe[:, -1] = last_cos.squeeze() pe = pe.unsqueeze(0).transpose(0, 1) - # Registers pe as a buffer that should not be considered a model parameter. self.register_buffer('pe', pe) def forward(self, x): - # Adds the positional encoding vector to the input embedding vector - x = x + self.pe[:x.size(0), :] + if self.mode == 'add': + # Adds the positional encoding vector to the input embedding vector + x = x + self.pe[:x.size(0), :] + elif self.mode == 'concat': + # Concatenates the positional encoding vector with the input embedding vector + x = torch.cat((x, self.pe[:x.size(0), :]), dim=-1) + else: + raise ValueError("Invalid mode. Choose 'add' or 'concat'.") return x From 35328cba69131ed982d74e6a751f9d2ccc5829d4 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Fri, 8 Dec 2023 12:28:16 +0200 Subject: [PATCH 8/9] Batch support for positional encoding --- src/skelcast/models/rnn/pvred.py | 8 ++++---- src/skelcast/models/transformers/base.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/skelcast/models/rnn/pvred.py b/src/skelcast/models/rnn/pvred.py index 9251bbb..dd2184a 100644 --- a/src/skelcast/models/rnn/pvred.py +++ b/src/skelcast/models/rnn/pvred.py @@ -156,10 +156,10 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64, self.pos_enc_method = pos_enc if self.pos_enc_method == 'concat': - self.pos_enc = PositionalEncoding(input_dim, mode='concat') + self.pos_enc = PositionalEncoding(d_model=input_dim, mode='concat') self.input_dim = self.input_dim + input_dim elif self.pos_enc_method == 'add': - self.pos_enc = PositionalEncoding(input_dim, mode='add') + self.pos_enc = PositionalEncoding(d_model=input_dim, mode='add') else: self.pos_enc = None @@ -167,10 +167,10 @@ def __init__(self, input_dim: int, enc_hidden_dim: int = 64, # Build encoder # TODO: Convert them to registry-type build - self.encoder = Encoder(rnn_type=enc_type, input_dim=input_dim, + self.encoder = Encoder(rnn_type=enc_type, input_dim=self.input_dim, hidden_dim=enc_hidden_dim, batch_first=batch_first) # Build decoder - self.decoder = Decoder(rnn_type=dec_type, input_dim=input_dim, + self.decoder = Decoder(rnn_type=dec_type, input_dim=self.input_dim, hidden_dim=dec_hidden_dim, batch_first=batch_first) diff --git a/src/skelcast/models/transformers/base.py b/src/skelcast/models/transformers/base.py index 4846804..e0b3ad2 100644 --- a/src/skelcast/models/transformers/base.py +++ b/src/skelcast/models/transformers/base.py @@ -82,16 +82,16 @@ def __init__(self, d_model, max_len=5000, mode='add'): last_cos = torch.cos(position * div_term[-1]) pe[:, -1] = last_cos.squeeze() - pe = pe.unsqueeze(0).transpose(0, 1) + pe = pe.unsqueeze(0) self.register_buffer('pe', pe) def forward(self, x): if self.mode == 'add': # Adds the positional encoding vector to the input embedding vector - x = x + self.pe[:x.size(0), :] + x = x + self.pe[:x.size(1), :].repeat(x.size(0), 1, 1) elif self.mode == 'concat': # Concatenates the positional encoding vector with the input embedding vector - x = torch.cat((x, self.pe[:x.size(0), :]), dim=-1) + x = torch.cat((x, self.pe[:, :x.size(1), :].repeat(x.size(0), 1, 1)), dim=-1) else: raise ValueError("Invalid mode. Choose 'add' or 'concat'.") return x From 11413e77be8260b2fd7e6663faf4e397d0b5c602 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Fri, 8 Dec 2023 13:48:26 +0200 Subject: [PATCH 9/9] Collate function that randomly samples scene segments --- src/skelcast/data/dataset.py | 48 ++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/src/skelcast/data/dataset.py b/src/skelcast/data/dataset.py index 62e81eb..a50f34b 100644 --- a/src/skelcast/data/dataset.py +++ b/src/skelcast/data/dataset.py @@ -2,7 +2,7 @@ import logging import pickle from dataclasses import dataclass -from typing import Any, Tuple, List +from typing import Any, Tuple, List, Optional import numpy as np import torch @@ -92,6 +92,7 @@ class NTURGBDSample: x: torch.tensor y: torch.tensor label: Tuple[int, str] + mask: Optional[torch.tensor] = None def nturbgd_collate_fn_with_overlapping_context_window(batch: List[NTURGBDSample]) -> NTURGBDSample: # TODO: Normalize each sample individually along its 3 axes @@ -102,7 +103,7 @@ def nturbgd_collate_fn_with_overlapping_context_window(batch: List[NTURGBDSample # batch_x = default_collate(batch_x) # batch_y = default_collate(batch_y) batch_label = default_collate(batch_label) - return NTURGBDSample(x=batch_x, y=batch_y, label=batch_label) + return NTURGBDSample(x=batch_x, y=batch_y, label=batch_label, mask=None) @COLLATE_FUNCS.register_module() @@ -161,7 +162,50 @@ def get_windows(self, x): input_windows_tensor = torch.tensor(input_windows, dtype=torch.float) target_labels_tensor = torch.tensor(np.array(target_labels), dtype=torch.float) return input_windows_tensor, target_labels_tensor + + +@COLLATE_FUNCS.register_module() +class NTURGBDCollateFnWithRandomSampledContextWindow: + """ + Custom collate function for batched variable-length sequences. + During the __call__ function, we creata `block_size`-long context windows, for each sequence in the batch. + If is_packed is True, we pack the padded sequences, otherwise we return the padded sequences as is. + + Args: + - block_size (int): Sequence's context length. + - is_packed (bool): Whether to pack the padded sequence or not. + + Returns: + + The batched padded sequences ready to be fed to a transformer or an lstm model. + """ + def __init__(self, block_size: int) -> None: + self.block_size = block_size + def __call__(self, batch) -> NTURGBDSample: + # Pick a random index for each element of the batch and create a context window of size `block_size` + # around that index + # If the batch element's sequence length is less than `block_size`, then we sample the entire sequence + # Pick the random index using pytorch + seq_lens = [sample.shape[0] for sample, _ in batch] + labels = [label for _, label in batch] + pre_batch = [] + for sample, _ in batch: + logging.debug(f'sample.shape: {sample.shape}') + if sample.shape[0] <= self.block_size: + # Sample the entire sequence + logging.debug(f'Detected a sample with a sample length of {sample.shape[0]}') + pre_batch.append(sample) + else: + # Sample a random index + idx = torch.randint(low=0, high=sample.shape[0] - self.block_size, size=(1,)).item() + pre_batch.append(sample[idx:idx + self.block_size, ...]) + # Pad the sequences to the maximum sequence length in the batch + batch_x = torch.nn.utils.rnn.pad_sequence(pre_batch, batch_first=True) + # Generate masks + masks = torch.nn.utils.rnn.pack_sequence([torch.ones(seq_len) for seq_len in seq_lens], enforce_sorted=False).to(torch.float32) + return NTURGBDSample(x=batch_x, y=batch_x, label=labels, mask=masks) + @DATASETS.register_module() class NTURGBDDataset(Dataset):