Skip to content

Commit

Permalink
Merge pull request #58 from kaseris/models/unet
Browse files Browse the repository at this point in the history
Models/unet
  • Loading branch information
kaseris committed Dec 21, 2023
2 parents e90e2c7 + 3c35f1c commit 2c98825
Showing 1 changed file with 50 additions and 9 deletions.
59 changes: 50 additions & 9 deletions src/skelcast/models/cnn/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,24 @@ def forward(self, x1, x2):

@MODELS.register_module()
class Unet(SkelcastModule):
def __init__(self, filters=64, seq_size=50, out_size=5):
"""
An implementation of the U-Net architecture described in the paper:
Accurate Monitoring of 24-h Real-World Movement Behavior in People with Cerebral Palsy Is Possible Using Multiple Wearable Sensors and Deep Learning.
https://www.mdpi.com/1424-8220/23/22/9045
Credits to its creator: Georgios Zampoukis
Args:
filters (int): Number of filters to use in the convolutional layers.
seq_size (int): Number of frames in the input sequence.
out_size (int): Number of output channels.
loss_fn (nn.Module): Loss function to use.
observe_until (int): Number of frames to observe before predicting.
ts_to_predict (int): Number of frames to predict.
use_padded_len_mask (bool): Whether to use a mask to ignore padded values.
"""
def __init__(self, filters=64, seq_size=50, out_size=5, loss_fn: nn.Module = None,
observe_until: int = 20, ts_to_predict: int = 5,
use_padded_len_mask: bool = False):
super().__init__()
# Encoder (contracting path) — the upsampling/decoder layers follow below
self.c1 = Conv2D(seq_size, filters, 1)
Expand All @@ -69,21 +86,20 @@ def __init__(self, filters=64, seq_size=50, out_size=5):
self.cc3 = CatConv2D(filters * 4, filters * 2, 1)
self.u4 = UpConv2D(filters * 2, filters, 1, mode='bilinear')
self.cc4 = CatConv2D(filters * 2, filters, 1)
self.outconv = Conv2D(filters, out_size, 1)
self.outconv = Conv2D(filters, out_size, 1)

self.loss_fn = loss_fn if loss_fn is not None else nn.SmoothL1Loss()
self.observe_until = observe_until
self.ts_to_predict = ts_to_predict
self.use_padded_len_mask = use_padded_len_mask

def forward(self, x):
x1 = self.c1(x)
print(f'x1 shape: {x1.shape}')
x2 = self.c2(x1)
print(f'x2 shape: {x2.shape}')
x3 = self.c3(x2)
print(f'x3 shape: {x3.shape}')
x4 = self.c4(x3)
print(f'x4 shape: {x4.shape}')
x5 = self.c5(x4)
print(f'x5 shape: {x5.shape}')
x = self.u1(x5)
print(f'u1 shape: {x.shape}')
x = self.cc1(x, x4)
x = self.u2(x)
x = self.cc2(x, x3)
Expand All @@ -92,4 +108,29 @@ def forward(self, x):
x = self.u4(x)
x = self.cc4(x, x1)
x = self.outconv(x)
return x
return x

def training_step(self, x: torch.Tensor, y: torch.Tensor = None, mask: torch.Tensor = None) -> dict:
    """Run a single training iteration.

    The target is derived from ``x`` itself (self-supervised forecasting);
    the ``y`` argument is accepted for interface compatibility but ignored.

    Args:
        x (torch.Tensor): Input batch of shape (batch, seq_len, n_skels, n_joints, dims).
        y (torch.Tensor): Unused; the target is sliced out of ``x``.
        mask (torch.Tensor): Padded-length mask, required when
            ``self.use_padded_len_mask`` is True.

    Returns:
        dict: ``{'out': prediction, 'loss': loss}``.
    """
    batch, timesteps, n_skeletons, joints, coords = x.shape
    # Collapse the skeleton axis; this view only succeeds when n_skels == 1
    # (element counts must match) — assumes single-skeleton samples, TODO confirm.
    frames = x.view(batch, timesteps, joints, coords)
    observed = frames[:, :self.observe_until, :, :]
    target = frames[:, self.observe_until:self.observe_until + self.ts_to_predict, :, :]
    # Reshape the mask the same way as the input, then keep only the
    # timesteps being predicted.
    if self.use_padded_len_mask:
        mask = mask.view(batch, timesteps, joints, coords)
        mask = mask[:, self.observe_until:self.observe_until + self.ts_to_predict, :]
    prediction = self(observed)
    # Zero out predictions at padded positions before computing the loss.
    if self.use_padded_len_mask:
        prediction = prediction * mask
    loss = self.loss_fn(prediction, target)
    return {'out': prediction, 'loss': loss}

@torch.no_grad()
def validation_step(self, x, y) -> dict:
    """Evaluate one validation batch with gradient tracking disabled.

    Args:
        x: Input observation tensor fed to the model.
        y: Ground-truth target compared against the model output.

    Returns:
        dict: ``{'out': prediction, 'loss': loss}``.
    """
    prediction = self(x)
    return {'out': prediction, 'loss': self.loss_fn(prediction, y)}

@torch.no_grad()
def predict(self):
    """Inference hook; intentionally unimplemented for now (no-op returning None)."""

0 comments on commit 2c98825

Please sign in to comment.