"""
DenseNet implementation in keras
"""
from typing import Union, Tuple

from keras import Model, Input
from keras.layers import (GlobalAveragePooling2D, Activation, Concatenate,
                          AveragePooling2D, Dropout, BatchNormalization)

from .net import BaseNet


class DenseNet(BaseNet):
    def __init__(self, input_shape: tuple = (176, 176, 3), classes: int = 80, **kwargs):
        """
        DenseNet builder
        :param input_shape: input shape
        :param classes: number of output classes
        """
        super().__init__(input_shape, classes, **kwargs)
        self.nb_channels = self.growth_rate = None

    def apply_bn_relu_conv(self, x, nb_filters: int, kernel_size: Union[int, Tuple[int, int]]):
        """
        apply a BN -> ReLU -> CONV block to tensor x
        Original Caffe implementation: BN -> Scale -> ReLU -> CONV
        In keras, BatchNormalization already includes the learnable scale/shift
        (gamma/beta), so the separate Scale layer can be skipped
        :param x: input tensor
        :param nb_filters: number of filters in conv layer
        :param kernel_size: kernel size of conv layer
        :return: output tensor
        """
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = self.Conv2D(nb_filters, kernel_size)(x)
        if self.dropout_rate:
            x = Dropout(self.dropout_rate)(x)
        return x

    def apply_dense_block(self, x, nb_layers: int):
        """
        apply a dense block to input tensor x
        :param x: input tensor
        :param nb_layers: number of layers in this block
        :return: output tensor
        """
        for _ in range(nb_layers):
            conv = self.apply_bn_relu_conv(x, self.growth_rate, (3, 3))
            # each layer's output is concatenated to its input, so the block
            # grows by growth_rate channels per layer
            x = Concatenate()([x, conv])
            self.nb_channels += self.growth_rate
        return x

    def apply_transition(self, x):
        """
        apply a transition layer to input tensor x
        :param x: input tensor
        :return: output tensor
        """
        x = self.apply_bn_relu_conv(x, self.nb_channels, (1, 1))
        x = AveragePooling2D()(x)
        return x

    def build(self,
              growth_rate: int = 12,
              block_config: tuple = (12, 12, 12),
              nb_first_output: int = 16,
              name='DenseNet') -> Model:
        """
        build a densenet model
        :param growth_rate: growth rate (k)
        :param block_config: number of layers in each dense block
        :param nb_first_output: output size of the first conv layer
        :param name: model name
        :return: model
        """
        self.nb_channels = nb_first_output
        self.growth_rate = growth_rate
        img_input = Input(shape=self.input_shape)
        x = self.Conv2D(nb_first_output, 3)(img_input)
        # dense blocks
        for i, nb_layer in enumerate(block_config):
            x = self.apply_dense_block(x, nb_layer)
            if i < len(block_config) - 1:  # no transition after the last dense block
                x = self.apply_transition(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = GlobalAveragePooling2D()(x)
        x = self.Dense(self.classes, 'softmax')(x)
        return Model(inputs=img_input, outputs=x, name=name)

    def build40(self):
        """
        build a densenet-40 model
        same as https://github.com/liuzhuang13/DenseNetCaffe
        :return: model
        """
        return self.build()

    def build22(self):
        """
        build a densenet-22 model
        3 dense blocks, 6 conv layers in each block
        :return: model
        """
        return self.build(block_config=(6, 6, 6))

    def build121(self):
        """
        build a densenet-121 model
        :return: model
        """
        return self.build(nb_first_output=64, growth_rate=32, block_config=(6, 12, 24, 16))

    def build169(self):
        """
        build a densenet-169 model
        :return: model
        """
        return self.build(nb_first_output=64, growth_rate=32, block_config=(6, 12, 32, 32))

    def build201(self):
        """
        build a densenet-201 model
        :return: model
        """
        return self.build(nb_first_output=64, growth_rate=32, block_config=(6, 12, 48, 32))

    def build161(self):
        """
        build a densenet-161 model
        :return: model
        """
        return self.build(nb_first_output=96, growth_rate=48, block_config=(6, 12, 36, 24))
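

# Minimal usage sketch (an addition, not part of the original file): builds the
# DenseNet-40 variant and prints its layer summary. It assumes BaseNet's
# constructor accepts only the arguments shown here and that the model is
# trained with a standard keras compile/fit workflow.
if __name__ == '__main__':
    net = DenseNet(input_shape=(176, 176, 3), classes=80)
    model = net.build40()
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    model.summary()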