# Sanghyun Woo, et al. "CBAM: Convolutional Block Attention Module." arXiv preprint arXiv:1807.06521v2 (2018).
class channel_attention(tf.keras.layers.Layer):
    """Channel attention module of CBAM for 5-D, channels-last feature maps.

    The input is pooled over its three spatial axes with both average and max
    pooling. Each pooled descriptor is passed through the same two-layer
    bottleneck MLP, the two results are summed and squashed with a sigmoid,
    and the resulting per-channel gate is multiplied onto the input.

    Reference: Woo et al., "CBAM: Convolutional Block Attention Module",
    https://arxiv.org/abs/1807.06521.

    Args:
        ratio: Reduction ratio of the bottleneck. The hidden layer has
            ``channels // ratio`` units, but never fewer than one.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, ratio=8, **kwargs):
        super(channel_attention, self).__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        """Create the shared bottleneck MLP from the channel count."""
        channel = input_shape[-1]
        if channel is None:
            raise ValueError(
                'The channel (last) axis of the input to `channel_attention` '
                'must be defined, got input shape {}.'.format(input_shape))
        # Guard against a zero-unit hidden layer when channel < ratio.
        hidden_units = max(channel // self.ratio, 1)
        self.shared_layer_one = tf.keras.layers.Dense(
            hidden_units,
            activation='relu',
            kernel_initializer='he_normal',
            use_bias=True,
            bias_initializer='zeros')
        self.shared_layer_two = tf.keras.layers.Dense(
            channel,
            kernel_initializer='he_normal',
            use_bias=True,
            bias_initializer='zeros')
        super(channel_attention, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """The gate rescales the input, so the shape is unchanged."""
        return input_shape

    def call(self, inputs):
        """Return ``inputs`` rescaled by the learned per-channel gate."""
        # Squeeze the spatial axes but keep them as size-1 dims so the gate
        # broadcasts back onto the input: (batch, 1, 1, 1, channels).
        spatial_axes = [1, 2, 3]
        avg_pool = tf.reduce_mean(inputs, axis=spatial_axes, keepdims=True)
        max_pool = tf.reduce_max(inputs, axis=spatial_axes, keepdims=True)

        # Both descriptors go through the *same* MLP (shared weights).
        avg_out = self.shared_layer_two(self.shared_layer_one(avg_pool))
        max_out = self.shared_layer_two(self.shared_layer_one(max_pool))

        gate = tf.sigmoid(avg_out + max_out)
        return inputs * gate

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(channel_attention, self).get_config()
        config.update({'ratio': self.ratio})
        return config
class spatial_attention(tf.keras.layers.Layer):
    """Spatial attention module of CBAM for 5-D, channels-last feature maps.

    The input is reduced along the channel axis with both mean and max, the
    two single-channel maps are concatenated and fed to one ``Conv3D`` with a
    sigmoid activation, and the resulting spatial map is multiplied onto the
    input.

    Reference: Woo et al., "CBAM: Convolutional Block Attention Module",
    https://arxiv.org/abs/1807.06521.

    Args:
        kernel_size: Kernel size of the convolution that produces the
            attention map.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super(spatial_attention, self).__init__(**kwargs)
        self.kernel_size = kernel_size

    def build(self, input_shape):
        """Create the convolution that turns the pooled maps into a gate."""
        self.conv3d = tf.keras.layers.Conv3D(
            filters=1,
            kernel_size=self.kernel_size,
            strides=1,
            padding='same',
            activation='sigmoid',
            kernel_initializer='he_normal',
            use_bias=False)
        super(spatial_attention, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """The gate rescales the input, so the shape is unchanged."""
        return input_shape

    def call(self, inputs):
        """Return ``inputs`` rescaled by the learned spatial gate."""
        # Channel-wise statistics, each keeping a trailing axis of size 1.
        avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
        max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
        descriptor = tf.concat([avg_pool, max_pool], axis=-1)

        # Single-channel map in (0, 1); broadcasts over the channel axis.
        attention_map = self.conv3d(descriptor)
        return inputs * attention_map

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(spatial_attention, self).get_config()
        config.update({'kernel_size': self.kernel_size})
        return config
# Jun Fu, et al. "Dual attention network for scene segmentation."
# Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.
class Channel_attention(tf.keras.layers.Layer):
    """Channel attention module of DANet for 5-D, channels-last feature maps.

    The feature map is flattened to ``(batch, N, C)`` with ``N`` the number of
    spatial positions. A ``C x C`` channel affinity matrix is computed from
    it, normalised with a softmax, and used to re-mix the channels. The
    result is scaled by the learnable scalar ``gamma`` and added back to the
    input.

    NOTE: the affinity is computed between channels (``C x C``). An earlier
    version of this layer multiplied the flattened features in the opposite
    order and therefore produced an ``N x N`` position affinity, which is not
    channel attention and needs memory quadratic in the volume size.

    Reference: Fu, Jun, et al. "Dual attention network for scene
    segmentation." Proceedings of the IEEE Conference on Computer Vision and
    Pattern Recognition. 2019.

    Args:
        gamma_initializer: Initializer for the residual scale ``gamma``.
        gamma_regularizer: Optional regularizer applied to ``gamma``.
        gamma_constraint: Optional constraint applied to ``gamma``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self,
                 gamma_initializer=tf.zeros_initializer(),
                 gamma_regularizer=None,
                 gamma_constraint=None,
                 **kwargs):
        super(Channel_attention, self).__init__(**kwargs)
        # `.get` accepts objects, names and serialized dicts, which lets
        # `from_config` round-trip the values produced by `get_config`.
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Create the scalar residual weight ``gamma``."""
        self.gamma = self.add_weight(shape=(1,),
                                     initializer=self.gamma_initializer,
                                     name='gamma',
                                     regularizer=self.gamma_regularizer,
                                     constraint=self.gamma_constraint)
        super(Channel_attention, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """The attention output is added to the input; shape is unchanged."""
        return input_shape

    def call(self, inputs):
        """Return ``gamma * channel_attention(inputs) + inputs``."""
        # Use the dynamic shape so undefined batch/spatial sizes are allowed.
        dynamic_shape = tf.shape(inputs)
        channels = inputs.shape[-1]
        if channels is None:
            channels = dynamic_shape[-1]

        # (batch, N, C) with N = product of the three spatial dimensions.
        flat = tf.reshape(inputs, tf.stack([dynamic_shape[0], -1, channels]))

        # energy[b, i, j] = <channel i, channel j>  ->  (batch, C, C)
        energy = tf.matmul(flat, flat, transpose_a=True)
        attention = tf.nn.softmax(energy, axis=-1)

        # outputs[b, n, i] = sum_j attention[b, i, j] * flat[b, n, j]
        outputs = tf.matmul(flat, attention, transpose_b=True)
        outputs = tf.reshape(outputs, dynamic_shape)

        return self.gamma * outputs + inputs

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(Channel_attention, self).get_config()
        config.update({
            'gamma_initializer':
                tf.keras.initializers.serialize(self.gamma_initializer),
            'gamma_regularizer':
                tf.keras.regularizers.serialize(self.gamma_regularizer),
            'gamma_constraint':
                tf.keras.constraints.serialize(self.gamma_constraint),
        })
        return config
class Position_attention(tf.keras.layers.Layer):
    """Position attention module of DANet for 5-D, channels-last feature maps.

    Three 1x1x1 convolutions project the input to query, key and value
    tensors. The query and key projections (with a reduced channel count) are
    flattened over the spatial axes and multiplied to give an ``N x N``
    affinity between positions, which is softmax-normalised and used to
    aggregate the value projection. The result is scaled by the learnable
    scalar ``gamma`` and added back to the input.

    Reference: Fu, Jun, et al. "Dual attention network for scene
    segmentation." Proceedings of the IEEE Conference on Computer Vision and
    Pattern Recognition. 2019.

    Args:
        ratio: Channel reduction ratio of the query/key projections; they use
            ``channels // ratio`` filters, but never fewer than one.
        gamma_initializer: Initializer for the residual scale ``gamma``.
        gamma_regularizer: Optional regularizer applied to ``gamma``.
        gamma_constraint: Optional constraint applied to ``gamma``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self,
                 ratio=8,
                 gamma_initializer=tf.zeros_initializer(),
                 gamma_regularizer=None,
                 gamma_constraint=None,
                 **kwargs):
        super(Position_attention, self).__init__(**kwargs)
        self.ratio = ratio
        # `.get` accepts objects, names and serialized dicts, which lets
        # `from_config` round-trip the values produced by `get_config`.
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Create the three projections and the residual weight ``gamma``."""
        channels = input_shape[-1]
        if channels is None:
            raise ValueError(
                'The channel (last) axis of the input to `Position_attention` '
                'must be defined, got input shape {}.'.format(input_shape))
        # Guard against zero filters when channels < ratio.
        reduced_channels = max(channels // self.ratio, 1)
        self._channels = channels
        self._reduced_channels = reduced_channels

        self.query_conv = tf.keras.layers.Conv3D(
            filters=reduced_channels,
            kernel_size=(1, 1, 1),
            use_bias=False,
            kernel_initializer='he_normal')
        self.key_conv = tf.keras.layers.Conv3D(
            filters=reduced_channels,
            kernel_size=(1, 1, 1),
            use_bias=False,
            kernel_initializer='he_normal')
        self.value_conv = tf.keras.layers.Conv3D(
            filters=channels,
            kernel_size=(1, 1, 1),
            use_bias=False,
            kernel_initializer='he_normal')
        self.gamma = self.add_weight(shape=(1,),
                                     initializer=self.gamma_initializer,
                                     name='gamma',
                                     regularizer=self.gamma_regularizer,
                                     constraint=self.gamma_constraint)
        super(Position_attention, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """The attention output is added to the input; shape is unchanged."""
        return input_shape

    def call(self, inputs):
        """Return ``gamma * position_attention(inputs) + inputs``."""
        # Use the dynamic shape so undefined batch/spatial sizes are allowed.
        dynamic_shape = tf.shape(inputs)
        batch = dynamic_shape[0]

        # Flatten the spatial axes: (batch, N, channels-of-projection).
        proj_query = tf.reshape(self.query_conv(inputs),
                                tf.stack([batch, -1, self._reduced_channels]))
        proj_key = tf.reshape(self.key_conv(inputs),
                              tf.stack([batch, -1, self._reduced_channels]))
        proj_value = tf.reshape(self.value_conv(inputs),
                                tf.stack([batch, -1, self._channels]))

        # energy[b, i, j] = <key at position i, query at position j>
        energy = tf.matmul(proj_key, proj_query, transpose_b=True)
        attention = tf.nn.softmax(energy, axis=-1)

        # Aggregate values over all positions: (batch, N, channels).
        outputs = tf.matmul(attention, proj_value)
        outputs = tf.reshape(outputs, dynamic_shape)

        return self.gamma * outputs + inputs

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(Position_attention, self).get_config()
        config.update({
            'ratio': self.ratio,
            'gamma_initializer':
                tf.keras.initializers.serialize(self.gamma_initializer),
            'gamma_regularizer':
                tf.keras.regularizers.serialize(self.gamma_regularizer),
            'gamma_constraint':
                tf.keras.constraints.serialize(self.gamma_constraint),
        })
        return config