initial docs and bam
anas-rz committed Feb 24, 2024
1 parent d118323 commit 009e54d
Showing 9 changed files with 139 additions and 0 deletions.
2 changes: 2 additions & 0 deletions k3_addons/layers/attention/a2a.py
@@ -5,6 +5,8 @@

@k3_export(path="k3_addons.layers.DoubleAttention")
class DoubleAttention(layers.Layer):
"""A2-Nets: Double Attention Networks [https://arxiv.org/pdf/1810.11579.pdf]"""

def __init__(self, dim, value_dim=None, reconstruct=True):
super().__init__()
self.dim = dim
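For reference, a minimal usage sketch of the layer documented above. It assumes, per the A2-Nets paper, a channels-last 4D feature map whose channel count matches dim; the shapes are illustrative, not taken from this commit.

import keras
from k3_addons.layers.attention.a2a import DoubleAttention

x = keras.random.uniform((2, 14, 14, 64))   # hypothetical (batch, height, width, channels)
layer = DoubleAttention(dim=64)              # value_dim defaults to None, reconstruct=True
y = layer(x)                                 # expected to keep the input shape when reconstruct=True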
2 changes: 2 additions & 0 deletions k3_addons/layers/attention/aft.py
@@ -4,6 +4,8 @@

@k3_export(path="k3_addons.layers.AFTFull")
class AFTFull(layers.Layer):
"""An Attention Free Transformer [https://arxiv.org/pdf/2105.14103v1.pdf]"""

def __init__(self, projection_dim, position_bias=False):
super(AFTFull, self).__init__()
self.position_bias = position_bias
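A minimal usage sketch for AFTFull, assuming, per the AFT paper, a token-sequence input of shape (batch, seq_len, dim) with projection_dim matching the embedding size; the shapes are illustrative.

import keras
from k3_addons.layers.attention.aft import AFTFull

x = keras.random.uniform((2, 49, 512))   # hypothetical (batch, tokens, dim)
layer = AFTFull(projection_dim=512)       # position_bias=False by default
y = layer(x)                              # expected shape: (2, 49, 512)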
95 changes: 95 additions & 0 deletions k3_addons/layers/attention/bam.py
@@ -0,0 +1,95 @@
from keras import layers, ops, Sequential, backend
from k3_addons.layers.pooling.adaptive_pooling import AdaptiveAveragePool2D
from k3_addons.api_export import k3_export


class ChannelAttention(layers.Layer):
    def __init__(self, reduction=16, num_layers=3):
        super().__init__()
        self.avgpool = AdaptiveAveragePool2D(1)
        self.reduction = reduction
        self.num_layers = num_layers

    def build(self, input_shape):
        # Pick the channel axis according to the backend image data format.
        if backend.image_data_format() == "channels_last":
            input_dim = input_shape[-1]
        else:
            input_dim = input_shape[1]
        gate_dims = [input_dim]
        gate_dims += [input_dim // self.reduction] * self.num_layers
        gate_dims += [input_dim]

        self.channel_attention = Sequential()
        self.channel_attention.add(layers.Flatten())
        for i in range(len(gate_dims) - 2):
            self.channel_attention.add(layers.Dense(gate_dims[i + 1]))
            self.channel_attention.add(layers.BatchNormalization())
            self.channel_attention.add(layers.Activation("relu"))
        self.channel_attention.add(layers.Dense(gate_dims[-1]))

    def call(self, x):
        if backend.image_data_format() == "channels_last":
            start_axis = 1
        else:
            start_axis = 2
        res = self.avgpool(x)  # b 1 1 c
        res = self.channel_attention(res)  # b c
        res = ops.expand_dims(res, axis=start_axis)  # b 1 c
        res = ops.expand_dims(res, axis=start_axis + 1)  # b 1 1 c
        res = ops.broadcast_to(res, ops.shape(x))
        return res


class SpatialAttention(layers.Layer):
    def __init__(self, reduction=16, num_layers=3, dilation_rate=2):
        super().__init__()

        self.reduction = reduction
        self.num_layers = num_layers
        self.dilation_rate = dilation_rate

    def build(self, input_shape):
        if backend.image_data_format() == "channels_last":
            input_dims = input_shape[-1]
        else:
            input_dims = input_shape[1]
        self.spatial_attention = Sequential()
        self.spatial_attention.add(
            layers.Conv2D(input_dims // self.reduction, kernel_size=1)
        )
        self.spatial_attention.add(layers.BatchNormalization())
        self.spatial_attention.add(layers.Activation("relu"))
        for i in range(self.num_layers):
            self.spatial_attention.add(layers.ZeroPadding2D(padding=1))
            self.spatial_attention.add(
                layers.Conv2D(
                    input_dims // self.reduction,
                    kernel_size=3,
                    dilation_rate=self.dilation_rate,
                )
            )
            self.spatial_attention.add(layers.BatchNormalization())
            self.spatial_attention.add(layers.Activation("relu"))
        self.spatial_attention.add(layers.Conv2D(1, kernel_size=1))

    def call(self, x):
        res = self.spatial_attention(x)
        res = ops.broadcast_to(res, ops.shape(x))
        return res


@k3_export(path="k3_addons.layers.BAMBlock")
class BAMBlock(layers.Layer):
    """
    BAM: Bottleneck Attention Module [https://arxiv.org/pdf/1807.06514.pdf]
    """

    def __init__(self, reduction=16, dilation_rate=2):
        super().__init__()
        self.channel_attention = ChannelAttention(reduction=reduction)
        self.spatial_attention = SpatialAttention(
            reduction=reduction, dilation_rate=dilation_rate
        )

    def call(self, x):
        ca_out = self.channel_attention(x)
        sa_out = self.spatial_attention(x)
        weight = ops.sigmoid(ca_out + sa_out)
        out = (1 + weight) * x
        return out
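For reference, a minimal usage sketch of the new BAMBlock: it refines a feature map as out = (1 + sigmoid(channel_attention(x) + spatial_attention(x))) * x. The input is a channels-last 4D tensor; the shape below simply mirrors the test that follows.

import keras
from keras import ops
from k3_addons.layers.attention.bam import BAMBlock

x = keras.random.uniform((1, 7, 7, 512))    # (batch, height, width, channels)
bam = BAMBlock(reduction=16, dilation_rate=2)
y = bam(x)                                   # attention-refined map, same shape as x
print(ops.shape(y))                          # (1, 7, 7, 512)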
13 changes: 13 additions & 0 deletions k3_addons/layers/attention/bam_test.py
@@ -0,0 +1,13 @@
import pytest
import keras
from keras import ops

from k3_addons.layers.attention.bam import BAMBlock


@pytest.mark.parametrize("input_shape", [(1, 7, 7, 512), (1, 7, 7, 128)])
def test_bam(input_shape):
    inputs = keras.random.uniform(input_shape)
    layer = BAMBlock(reduction=8)
    outputs = layer(inputs)
    assert ops.shape(outputs) == input_shape
4 changes: 4 additions & 0 deletions k3_addons/layers/attention/cbam.py
@@ -59,6 +59,10 @@ def call(self, x):

@k3_export("k3_addons.layers.CBAM")
class CBAMBlock(layers.Layer):
"""
CBAM: Convolutional Block Attention Module [https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf]
"""

def __init__(self, reduction=16, kernel_size=49):
super().__init__()
self.channel_attention = ChannelAttention(reduction=reduction)
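A minimal usage sketch for CBAMBlock, assuming a channels-last 4D feature map as in the CBAM paper; the kernel_size and shapes below are illustrative choices, not values from this commit.

import keras
from k3_addons.layers.attention.cbam import CBAMBlock

x = keras.random.uniform((2, 7, 7, 256))       # hypothetical feature map
cbam = CBAMBlock(reduction=16, kernel_size=7)   # kernel_size chosen for illustration
y = cbam(x)                                     # expected to keep the input shape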
4 changes: 4 additions & 0 deletions k3_addons/layers/attention/eca.py
@@ -5,6 +5,10 @@

@k3_export("k3_addons.layers.ECAAttention")
class ECAAttention(layers.Layer):
"""
ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks [https://arxiv.org/pdf/1910.03151.pdf]
"""

def __init__(self, kernel_size=3):
super().__init__()
self.pooling = AdaptiveAveragePool2D(1)
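A minimal usage sketch for ECAAttention, assuming a channels-last 4D feature map; kernel_size is the size of the 1D convolution over the pooled channel descriptor, and the shapes are illustrative.

import keras
from k3_addons.layers.attention.eca import ECAAttention

x = keras.random.uniform((2, 14, 14, 128))   # hypothetical feature map
eca = ECAAttention(kernel_size=3)
y = eca(x)                                    # expected to keep the input shape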
4 changes: 4 additions & 0 deletions k3_addons/layers/attention/external_attention.py
@@ -4,6 +4,10 @@

@k3_export("k3_addons.layers.ExternalAttention")
class ExternalAttention(layers.Layer):
"""
Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks [https://arxiv.org/abs/2105.02358]
"""

def __init__(self, intermediate_dim=64):
super().__init__()
self.intermediate_dim = intermediate_dim
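A minimal usage sketch for ExternalAttention, assuming, per the paper, a token-sequence input of shape (batch, tokens, dim) with intermediate_dim setting the size of the shared external memory; shapes are illustrative.

import keras
from k3_addons.layers.attention.external_attention import ExternalAttention

x = keras.random.uniform((2, 49, 512))       # hypothetical (batch, tokens, dim)
ea = ExternalAttention(intermediate_dim=64)
y = ea(x)                                     # expected shape: (2, 49, 512)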
2 changes: 2 additions & 0 deletions k3_addons/layers/attention/residual.py
@@ -4,6 +4,8 @@

@k3_export(path="k3_addons.layers.ResidualAttention")
class ResidualAttention(layers.Layer):
"""Residual Attention: A Simple but Effective Method for Multi-Label Recognition [https://arxiv.org/abs/2108.02456]"""

def __init__(self, num_class=1000, alpha=0.2):
super().__init__()
self.alpha = alpha
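A minimal usage sketch for ResidualAttention, assuming, per the paper, a channels-last backbone feature map in and per-class logits of shape (batch, num_class) out; the shapes are illustrative assumptions.

import keras
from k3_addons.layers.attention.residual import ResidualAttention

x = keras.random.uniform((2, 7, 7, 2048))    # hypothetical backbone feature map
head = ResidualAttention(num_class=1000, alpha=0.2)
logits = head(x)                              # expected shape: (2, 1000)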
13 changes: 13 additions & 0 deletions k3_addons/layers/pooling/adaptive_pooling.py
@@ -95,6 +95,10 @@ def compute_output_shape(self, input_shape):

@k3_export(path="k3_addons.layers.AdaptiveMaxPool1D")
class AdaptiveMaxPool1D(BaseAdaptivePool):
"""
Adaptive Pooling like torch.nn.AdaptiveMaxPool1d
"""

def __init__(self, output_size, data_format=None, padding="valid", **kwargs):
super(AdaptiveMaxPool1D, self).__init__(
output_size,
@@ -108,6 +112,9 @@ def __init__(self, output_size, data_format=None, padding="valid", **kwargs):

@k3_export(path="k3_addons.layers.AdaptiveAveragePool1D")
class AdaptiveAveragePool1D(BaseAdaptivePool):
""" Adaptive Pooling like torch.nn.AdaptiveAvgPool1d
"""
def __init__(self, output_size, data_format=None, padding="valid", **kwargs):
super(AdaptiveAveragePool1D, self).__init__(
output_size,
@@ -121,6 +128,9 @@ def __init__(self, output_size, data_format=None, padding="valid", **kwargs):

@k3_export(path="k3_addons.layers.AdaptiveMaxPool2D")
class AdaptiveMaxPool2D(BaseAdaptivePool):
""" Adaptive Pooling like torch.nn.AdaptiveMaxPool2d
"""
def __init__(self, output_size, data_format=None, padding="valid", **kwargs):
super(AdaptiveMaxPool2D, self).__init__(
output_size,
@@ -134,6 +144,9 @@ def __init__(self, output_size, data_format=None, padding="valid", **kwargs):

@k3_export(path="k3_addons.layers.AdaptiveAveragePool2D")
class AdaptiveAveragePool2D(BaseAdaptivePool):
""" Adaptive Pooling like torch.nn.AdaptiveAvgPool2d
"""
def __init__(self, output_size, data_format=None, padding="valid", **kwargs):
super(AdaptiveAveragePool2D, self).__init__(
output_size,
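A minimal usage sketch for the adaptive pooling layers documented above: like their torch counterparts, output_size fixes the pooled spatial size regardless of input resolution. The (1, 1) case is the one bam.py relies on; the other shapes are illustrative.

import keras
from keras import ops
from k3_addons.layers.pooling.adaptive_pooling import (
    AdaptiveAveragePool2D,
    AdaptiveMaxPool1D,
)

images = keras.random.uniform((2, 32, 32, 64))         # channels-last 4D input
print(ops.shape(AdaptiveAveragePool2D(1)(images)))      # (2, 1, 1, 64), as used in bam.py
print(ops.shape(AdaptiveAveragePool2D(7)(images)))      # expected: (2, 7, 7, 64)

seq = keras.random.uniform((2, 100, 16))                # (batch, steps, features)
print(ops.shape(AdaptiveMaxPool1D(4)(seq)))             # expected: (2, 4, 16)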
