From 63f939fc2c7e5428a65554413af5eba3cbc04b04 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 21 Dec 2023 14:10:18 +0200 Subject: [PATCH 1/2] Compute the dead neurons --- src/skelcast/experiments/runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/skelcast/experiments/runner.py b/src/skelcast/experiments/runner.py index 5a3d125..6ae305e 100644 --- a/src/skelcast/experiments/runner.py +++ b/src/skelcast/experiments/runner.py @@ -192,6 +192,9 @@ def training_step(self, train_batch: NTURGBDSample): # Calculate the saturation of the tanh output saturated = (outputs.abs() > 0.95) saturation_percentage = saturated.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100 + # Calculate the dead neurons + dead_neurons = (outputs.abs() < 0.05) + dead_neurons_percentage = dead_neurons.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100 self.optimizer.zero_grad() loss.backward() if self.log_gradient_info: @@ -211,6 +214,7 @@ def training_step(self, train_batch: NTURGBDSample): if self.logger is not None: self.logger.add_scalar(tag='train/saturation', scalar_value=saturation_percentage.mean().item(), global_step=len(self.training_loss_per_step)) + self.logger.add_scalar(tag='train/dead_neurons', scalar_value=dead_neurons_percentage.mean().item(), global_step=len(self.training_loss_per_step)) self.optimizer.step() # Print the loss From c36f8972994b60faa083b02cce68d6c6208e639e Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 21 Dec 2023 14:18:44 +0200 Subject: [PATCH 2/2] UNet module. --- src/skelcast/models/__init__.py | 3 +- src/skelcast/models/cnn/__init__.py | 0 src/skelcast/models/cnn/unet.py | 95 +++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 src/skelcast/models/cnn/__init__.py create mode 100644 src/skelcast/models/cnn/unet.py diff --git a/src/skelcast/models/__init__.py b/src/skelcast/models/__init__.py index 757bcf8..70e1e90 100644 --- a/src/skelcast/models/__init__.py +++ b/src/skelcast/models/__init__.py @@ -8,4 +8,5 @@ from .rnn.lstm import SimpleLSTMRegressor from .transformers.transformer import ForecastTransformer from .rnn.pvred import PositionalVelocityRecurrentEncoderDecoder -from .rnn.pvred import Encoder, Decoder \ No newline at end of file +from .rnn.pvred import Encoder, Decoder +from .cnn.unet import Unet \ No newline at end of file diff --git a/src/skelcast/models/cnn/__init__.py b/src/skelcast/models/cnn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/skelcast/models/cnn/unet.py b/src/skelcast/models/cnn/unet.py new file mode 100644 index 0000000..18a0a1c --- /dev/null +++ b/src/skelcast/models/cnn/unet.py @@ -0,0 +1,95 @@ +import torch +import torch.nn as nn + +from skelcast.models import MODELS +from skelcast.models.module import SkelcastModule + + +class Conv2D(nn.Module): + def __init__(self, in_channels, out_channels, kernel, stride=1): + super().__init__() + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel, stride=stride, padding='same', padding_mode='reflect', bias=False) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.PReLU(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class UpConv2D(nn.Module): + def __init__(self, in_channels, out_channels, kernel, mode='bilinear'): + super().__init__() + self.us = nn.Upsample(scale_factor=1, mode=mode) + self.conv = Conv2D(in_channels, out_channels, kernel, stride=1) + + def forward(self, x): + x = self.us(x) + x = self.conv(x) + return x + + +class DownConv2D(nn.Module): + def __init__(self, in_channels, out_channels, kernel): + super().__init__() + self.conv = Conv2D(in_channels, out_channels, kernel, stride=1) + + def forward(self, x): + return self.conv(x) + + +class CatConv2D(nn.Module): + def __init__(self, in_channels, out_channels, kernel=1): + super().__init__() + self.conv = Conv2D(in_channels, out_channels, kernel, stride=1) + + def forward(self, x1, x2): + return self.conv(torch.cat([x1, x2], dim=1)) + + +@MODELS.register_module() +class Unet(SkelcastModule): + def __init__(self, filters=64, seq_size=50, out_size=5): + super().__init__() + # Decoder + self.c1 = Conv2D(seq_size, filters, 1) + self.c2 = DownConv2D(filters, filters * 2, 1) + self.c3 = DownConv2D(filters * 2, filters * 4, 1) + self.c4 = DownConv2D(filters * 4, filters * 8, 1) + # Bottleneck + self.c5 = DownConv2D(filters * 8, filters * 8, 1) + # Encoder + self.u1 = UpConv2D(filters * 8, filters * 8, 1, mode='bilinear') + self.cc1 = CatConv2D(filters * 16, filters * 8, 1) + self.u2 = UpConv2D(filters * 8, filters * 4, 1, mode='bilinear') + self.cc2 = CatConv2D(filters * 8, filters * 4, 1) + self.u3 = UpConv2D(filters * 4, filters * 2, 1, mode='bilinear') + self.cc3 = CatConv2D(filters * 4, filters * 2, 1) + self.u4 = UpConv2D(filters * 2, filters, 1, mode='bilinear') + self.cc4 = CatConv2D(filters * 2, filters, 1) + self.outconv = Conv2D(filters, out_size, 1) + + def forward(self, x): + x1 = self.c1(x) + print(f'x1 shape: {x1.shape}') + x2 = self.c2(x1) + print(f'x2 shape: {x2.shape}') + x3 = self.c3(x2) + print(f'x3 shape: {x3.shape}') + x4 = self.c4(x3) + print(f'x4 shape: {x4.shape}') + x5 = self.c5(x4) + print(f'x5 shape: {x5.shape}') + x = self.u1(x5) + print(f'u1 shape: {x.shape}') + x = self.cc1(x, x4) + x = self.u2(x) + x = self.cc2(x, x3) + x = self.u3(x) + x = self.cc3(x, x2) + x = self.u4(x) + x = self.cc4(x, x1) + x = self.outconv(x) + return x \ No newline at end of file