-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodel.py
100 lines (71 loc) · 2.82 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
import torchvision.models as models
from modules import ASPPModule, DecoderModule, SEModule
from torchsummary import summary
from typing import Any
class DeepLabV3Plus(nn.Module):
    """DeepLabV3+-style segmentation network on a pretrained ResNet-50 backbone.

    The backbone is torchvision's ResNet-50 with its last two children
    removed.  ``forward`` taps it at two depths: all but the last three
    stages give the low-level features, and all but the last stage give the
    features fed to the ASPP module (for torchvision's ResNet-50 these are
    expected to be the ``layer1`` and ``layer3`` outputs, matching the
    hard-coded 1024 ASPP input channels).

    Note: the final backbone stage is never used by ``forward``; it is kept
    inside ``self.backbone`` so that ``state_dict`` keys stay unchanged.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
            The output always passes through a sigmoid, so the default of 1
            corresponds to binary segmentation.
    """

    def __init__(self, num_classes: int = 1) -> None:
        super().__init__()

        # Pretrained ResNet-50 without its last two children.
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        in_channels = 1024
        out_channels = 256

        # Dilation rates handed to the ASPP module.
        dilations = [6, 12, 18, 24]

        # ASPP module applied to the high-level backbone features.
        self.aspp = ASPPModule(in_channels, out_channels, dilations)

        # Decoder module; called with (high-level, low-level) features.
        self.decoder = DecoderModule(out_channels, out_channels)

        # x4 bilinear upsampling, used both before and after the decoder.
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=4)

        # Dropout applied to the upsampled ASPP output.
        self.dropout = nn.Dropout(p=0.5)

        # Final 1x1 convolution mapping decoder features to class channels.
        self.final_conv = nn.Conv2d(out_channels, num_classes, kernel_size=1)

        # Sigmoid activation for binary segmentation.
        self.sigmoid = nn.Sigmoid()

        # Initialize the newly created (non-pretrained) layers.
        self.init_weights()

    def init_weights(self) -> None:
        """Kaiming-initialize Conv2d/Linear layers outside the backbone.

        Weights get ``kaiming_normal_`` (fan_out, relu) and biases are set to
        zero.  Modules that belong to ``self.backbone`` are skipped: they
        carry the pretrained weights loaded in ``__init__``, and
        re-initializing them would silently discard that pretraining.
        """
        pretrained_ids = {id(module) for module in self.backbone.modules()}
        for module in self.modules():
            if id(module) in pretrained_ids:
                continue
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the segmentation forward pass.

        Args:
            x: Input batch passed straight to the backbone
                (presumably ``(N, 3, H, W)`` images - confirm with callers).

        Returns:
            Sigmoid-activated output of the final 1x1 convolution, with
            ``num_classes`` channels.
        """
        # Low-level features: every backbone stage except the last three.
        low_level = self.backbone[:-3](x)

        # High-level features: continue from the low-level features through
        # the next two stages.  This equals ``self.backbone[:-1](x)`` but
        # does not run the early stages a second time.
        # NOTE: relies on the first op of those stages not modifying its
        # input in place (true for torchvision's ResNet bottleneck blocks).
        high_level = self.backbone[-3:-1](low_level)

        # ASPP forward pass on the high-level features.
        high_level = self.aspp(high_level)

        # Upsample the ASPP output x4, then apply dropout.
        high_level = self.upsample(high_level)
        high_level = self.dropout(high_level)

        # Decoder combines the upsampled high-level and low-level features.
        fused = self.decoder(high_level, low_level)

        # Upsample the decoder output x4.
        fused = self.upsample(fused)

        # Final 1x1 convolution followed by sigmoid.
        logits = self.final_conv(fused)
        return self.sigmoid(logits)
if __name__ == "__main__":
    # Smoke test: build the model, push one random batch through it, and
    # print the output shape followed by a torchsummary layer summary.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # num_classes=1 -> single-channel output for binary segmentation.
    net = DeepLabV3Plus(num_classes=1).to(device)

    # Random input batch: 2 samples of shape (channels, height, width).
    n_samples = 2
    input_shape = (3, 256, 256)
    dummy_batch = torch.randn(n_samples, *input_shape).to(device)

    # Forward pass on the dummy batch.
    prediction = net(dummy_batch)
    print("Output shape:", prediction.shape)

    summary(net, input_size=input_shape)