diff --git a/src/skelcast/models/cnn/unet.py b/src/skelcast/models/cnn/unet.py index 18a0a1c..d184b1b 100644 --- a/src/skelcast/models/cnn/unet.py +++ b/src/skelcast/models/cnn/unet.py @@ -51,7 +51,24 @@ def forward(self, x1, x2): @MODELS.register_module() class Unet(SkelcastModule): - def __init__(self, filters=64, seq_size=50, out_size=5): + """ + An implementation of the U-Net architecture described in the paper: + Accurate Monitoring of 24-h Real-World Movement Behavior in People with Cerebral Palsy Is Possible Using Multiple Wearable Sensors and Deep Learning. + https://www.mdpi.com/1424-8220/23/22/9045 + + Credits to its creator: Georgios Zampoukis + Args: + filters (int): Number of filters to use in the convolutional layers. + seq_size (int): Number of frames in the input sequence. + out_size (int): Number of output channels. + loss_fn (nn.Module): Loss function to use. + observe_until (int): Number of frames to observe before predicting. + ts_to_predict (int): Number of frames to predict. + use_padded_len_mask (bool): Whether to use a mask to ignore padded values. 
+ """ + def __init__(self, filters=64, seq_size=50, out_size=5, loss_fn: nn.Module = None, + observe_until: int = 20, ts_to_predict: int = 5, + use_padded_len_mask: bool = False): super().__init__() # Decoder self.c1 = Conv2D(seq_size, filters, 1) @@ -69,21 +86,20 @@ def __init__(self, filters=64, seq_size=50, out_size=5): self.cc3 = CatConv2D(filters * 4, filters * 2, 1) self.u4 = UpConv2D(filters * 2, filters, 1, mode='bilinear') self.cc4 = CatConv2D(filters * 2, filters, 1) - self.outconv = Conv2D(filters, out_size, 1) + self.outconv = Conv2D(filters, out_size, 1) + + self.loss_fn = loss_fn if loss_fn is not None else nn.SmoothL1Loss() + self.observe_until = observe_until + self.ts_to_predict = ts_to_predict + self.use_padded_len_mask = use_padded_len_mask def forward(self, x): x1 = self.c1(x) - print(f'x1 shape: {x1.shape}') x2 = self.c2(x1) - print(f'x2 shape: {x2.shape}') x3 = self.c3(x2) - print(f'x3 shape: {x3.shape}') x4 = self.c4(x3) - print(f'x4 shape: {x4.shape}') x5 = self.c5(x4) - print(f'x5 shape: {x5.shape}') x = self.u1(x5) - print(f'u1 shape: {x.shape}') x = self.cc1(x, x4) x = self.u2(x) x = self.cc2(x, x3) @@ -92,4 +108,29 @@ def forward(self, x): x = self.u4(x) x = self.cc4(x, x1) x = self.outconv(x) - return x \ No newline at end of file + return x + + def training_step(self, x: torch.Tensor, y: torch.Tensor = None, mask: torch.Tensor = None) -> dict: + batch_size, seq_len, n_skels, n_joints, dims = x.shape + x = x.view(batch_size, seq_len, n_joints, dims) + x_observe = x[:, :self.observe_until, :, :] + y = x[:, self.observe_until:self.observe_until + self.ts_to_predict, :, :] + # View the mask as the x tensor + if self.use_padded_len_mask: + mask = mask.view(batch_size, seq_len, n_joints, dims) + mask = mask[:, self.observe_until:self.observe_until + self.ts_to_predict, :] + out = self(x_observe) + if self.use_padded_len_mask: + out = out * mask + loss = self.loss_fn(out, y) + return {'out': out, 'loss': loss} + + @torch.no_grad() + def 
validation_step(self, x, y) -> dict: + out = self(x) + loss = self.loss_fn(out, y) + return {'out': out, 'loss': loss} + + @torch.no_grad() + def predict(self): + pass \ No newline at end of file