
Commit

Merge pull request #57 from kaseris/models/unet
Models/unet
kaseris committed Dec 21, 2023
2 parents 9408041 + c36f897 commit e90e2c7
Showing 4 changed files with 101 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/skelcast/experiments/runner.py
@@ -192,6 +192,9 @@ def training_step(self, train_batch: NTURGBDSample):
        # Calculate the saturation of the tanh output
        saturated = (outputs.abs() > 0.95)
        saturation_percentage = saturated.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100
        # Calculate the dead neurons
        dead_neurons = (outputs.abs() < 0.05)
        dead_neurons_percentage = dead_neurons.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100
        self.optimizer.zero_grad()
        loss.backward()
        if self.log_gradient_info:
@@ -211,6 +214,7 @@ def training_step(self, train_batch: NTURGBDSample):

        if self.logger is not None:
            self.logger.add_scalar(tag='train/saturation', scalar_value=saturation_percentage.mean().item(), global_step=len(self.training_loss_per_step))
            self.logger.add_scalar(tag='train/dead_neurons', scalar_value=dead_neurons_percentage.mean().item(), global_step=len(self.training_loss_per_step))

        self.optimizer.step()
        # Print the loss
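For context, the two health metrics logged above can be reproduced in isolation. A minimal standalone sketch, assuming `outputs` is a tanh-activated prediction tensor of shape (batch, seq_len, features) — the shape is an assumption, while the 0.95/0.05 thresholds and the per-sample reduction follow the diff above:

import torch

# Dummy tanh-activated model output; the (batch, seq_len, features) layout is assumed.
outputs = torch.tanh(torch.randn(8, 25, 75))

# Share of activations pushed against the tanh limits (|x| > 0.95), per sample.
saturated = (outputs.abs() > 0.95)
saturation_percentage = saturated.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100

# Share of activations stuck near zero (|x| < 0.05), per sample.
dead_neurons = (outputs.abs() < 0.05)
dead_neurons_percentage = dead_neurons.sum(dim=(1, 2)).float() / (outputs.size(1) * outputs.size(2)) * 100

print(saturation_percentage.mean().item(), dead_neurons_percentage.mean().item())
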
3 changes: 2 additions & 1 deletion src/skelcast/models/__init__.py
@@ -8,4 +8,5 @@
from .rnn.lstm import SimpleLSTMRegressor
from .transformers.transformer import ForecastTransformer
from .rnn.pvred import PositionalVelocityRecurrentEncoderDecoder
from .rnn.pvred import Encoder, Decoder
from .cnn.unet import Unet
0 changes: 0 additions & 0 deletions src/skelcast/models/cnn/__init__.py
Empty file.
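With these two `__init__.py` changes in place, the model is importable from the package root, and importing it runs the `@MODELS.register_module()` decorator shown in `unet.py` below, so the class also lands in the `MODELS` registry. A minimal sketch of the direct import, assuming the `skelcast` package from this revision is on the path (constructor arguments are simply the defaults from the class definition):

from skelcast.models import Unet

model = Unet(filters=64, seq_size=50, out_size=5)  # defaults from the Unet constructor below
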
95 changes: 95 additions & 0 deletions src/skelcast/models/cnn/unet.py
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn

from skelcast.models import MODELS
from skelcast.models.module import SkelcastModule


class Conv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel, stride=stride, padding='same', padding_mode='reflect', bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.PReLU(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class UpConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, mode='bilinear'):
        super().__init__()
        self.us = nn.Upsample(scale_factor=1, mode=mode)
        self.conv = Conv2D(in_channels, out_channels, kernel, stride=1)

    def forward(self, x):
        x = self.us(x)
        x = self.conv(x)
        return x


class DownConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.conv = Conv2D(in_channels, out_channels, kernel, stride=1)

    def forward(self, x):
        return self.conv(x)


class CatConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel=1):
        super().__init__()
        self.conv = Conv2D(in_channels, out_channels, kernel, stride=1)

    def forward(self, x1, x2):
        return self.conv(torch.cat([x1, x2], dim=1))


@MODELS.register_module()
class Unet(SkelcastModule):
    def __init__(self, filters=64, seq_size=50, out_size=5):
        super().__init__()
        # Encoder (contracting path)
        self.c1 = Conv2D(seq_size, filters, 1)
        self.c2 = DownConv2D(filters, filters * 2, 1)
        self.c3 = DownConv2D(filters * 2, filters * 4, 1)
        self.c4 = DownConv2D(filters * 4, filters * 8, 1)
        # Bottleneck
        self.c5 = DownConv2D(filters * 8, filters * 8, 1)
        # Decoder (expanding path)
        self.u1 = UpConv2D(filters * 8, filters * 8, 1, mode='bilinear')
        self.cc1 = CatConv2D(filters * 16, filters * 8, 1)
        self.u2 = UpConv2D(filters * 8, filters * 4, 1, mode='bilinear')
        self.cc2 = CatConv2D(filters * 8, filters * 4, 1)
        self.u3 = UpConv2D(filters * 4, filters * 2, 1, mode='bilinear')
        self.cc3 = CatConv2D(filters * 4, filters * 2, 1)
        self.u4 = UpConv2D(filters * 2, filters, 1, mode='bilinear')
        self.cc4 = CatConv2D(filters * 2, filters, 1)
        self.outconv = Conv2D(filters, out_size, 1)

    def forward(self, x):
        x1 = self.c1(x)
        print(f'x1 shape: {x1.shape}')
        x2 = self.c2(x1)
        print(f'x2 shape: {x2.shape}')
        x3 = self.c3(x2)
        print(f'x3 shape: {x3.shape}')
        x4 = self.c4(x3)
        print(f'x4 shape: {x4.shape}')
        x5 = self.c5(x4)
        print(f'x5 shape: {x5.shape}')
        x = self.u1(x5)
        print(f'u1 shape: {x.shape}')
        x = self.cc1(x, x4)
        x = self.u2(x)
        x = self.cc2(x, x3)
        x = self.u3(x)
        x = self.cc3(x, x2)
        x = self.u4(x)
        x = self.cc4(x, x1)
        x = self.outconv(x)
        return x
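
As a quick sanity check of the channel bookkeeping: every convolution is stride-1 with 'same' padding and the upsampling uses scale_factor=1, so spatial dimensions are preserved and only the channel dimension changes from seq_size to out_size. A hedged sketch of a forward pass follows; the (batch, seq_size, joints, coords) input layout is an assumption about how a skeleton sequence would be fed in, and the intermediate shapes are printed by `forward` itself.

import torch

from skelcast.models.cnn.unet import Unet

model = Unet(filters=64, seq_size=50, out_size=5)
x = torch.randn(2, 50, 25, 3)  # assumed layout: (batch, seq_size, joints, coords)
y = model(x)
print(y.shape)  # expected: torch.Size([2, 5, 25, 3])

Because nothing in the network changes the spatial resolution, the skip concatenations in cc1 through cc4 always line up, whatever the joint and coordinate dimensions of the input are.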
