Commit

nits
anas-rz committed Feb 24, 2024
1 parent 6af1b8c commit b25f4ab
Showing 5 changed files with 111 additions and 1 deletion.
Binary file added .assets/k-addons.png
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,2 +1,3 @@
 __pycache__
-.pytest_cache
+.pytest_cache
+dist/
18 changes: 18 additions & 0 deletions README.md
@@ -0,0 +1,18 @@
# k3-addons: Additional multi-backend functionality for Keras 3.
![Logo](.assets/k-addons.png)

# Includes:
- Layers
- Pooling:
- `AdaptiveAveragePooling1D`
- `AdaptiveMaxPooling1D`
- `AdaptiveAveragePooling2D`
- `AdaptiveMaxPooling2D`
- Attention:
- `DoubleAttention`
- `AFTFull`
- `ChannelAttention2D`
- `SpatialAttention2D`
- `ECAAttention`
- `ExternalAttention`
- `ResidualAttention`
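
A minimal usage sketch follows. The `k3_addons.layers` import path is an assumption inferred from the `k3_export` decorators added in this commit, not a documented public API:

```python
# Hypothetical usage; the k3_addons.layers path is inferred from the
# k3_export decorators in this commit.
import numpy as np
from k3_addons import layers as addon_layers

x = np.random.random((2, 32, 32, 64)).astype("float32")  # channels_last batch
eca = addon_layers.ECAAttention(kernel_size=3)
print(eca(x).shape)  # attention layers gate the input, so shape is unchanged
```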
68 changes: 68 additions & 0 deletions k3_addons/layers/attention/cbam.py
@@ -0,0 +1,68 @@
from keras import layers, ops, Sequential, backend

from k3_addons.layers.pooling.adaptive_pooling import (
    AdaptiveMaxPool2D,
    AdaptiveAveragePool2D,
)
from k3_addons.api_export import k3_export


@k3_export("k3_addons.layers.ChannelAttention2D")
class ChannelAttention(layers.Layer):
    """Channel attention: squeeze spatial dims, produce a per-channel gate."""

    def __init__(self, reduction=16):
        super().__init__()
        self.reduction = reduction
        self.maxpool = AdaptiveMaxPool2D(1)
        self.avgpool = AdaptiveAveragePool2D(1)
        self.data_format = backend.image_data_format()

    def build(self, input_shape):
        if self.data_format == "channels_last":
            input_dim = input_shape[3]
        else:
            input_dim = input_shape[1]
        # Shared two-layer bottleneck MLP, implemented as 1x1 convolutions.
        self.se = Sequential(
            [
                layers.Conv2D(input_dim // self.reduction, 1, use_bias=False),
                layers.Activation("relu"),
                layers.Conv2D(input_dim, 1, use_bias=False),
            ]
        )

    def call(self, x):
        # Pool to 1x1 spatially, run both pooled maps through the shared MLP,
        # and fuse them into a sigmoid gate over channels.
        max_result = self.maxpool(x)
        avg_result = self.avgpool(x)
        max_out = self.se(max_result)
        avg_out = self.se(avg_result)
        return ops.sigmoid(max_out + avg_out)


@k3_export("k3_addons.layers.SpatialAttention2D")
class SpatialAttention(layers.Layer):
    """Spatial attention: squeeze channels, produce a per-location gate."""

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = layers.Conv2D(1, kernel_size=kernel_size, padding="same")

    def call(self, x):
        channel_axis = 1 if backend.image_data_format() == "channels_first" else 3
        # Max- and mean-pool across channels, then convolve the stacked
        # 2-channel map down to a single-channel sigmoid gate.
        max_result = ops.max(x, axis=channel_axis, keepdims=True)
        avg_result = ops.mean(x, axis=channel_axis, keepdims=True)
        result = ops.concatenate([max_result, avg_result], channel_axis)
        return ops.sigmoid(self.conv(result))


@k3_export("k3_addons.layers.CBAM")
class CBAMBlock(layers.Layer):
    """Convolutional Block Attention Module: channel then spatial attention."""

    def __init__(self, reduction=16, kernel_size=49):
        super().__init__()
        self.channel_attention = ChannelAttention(reduction=reduction)
        self.spatial_attention = SpatialAttention(kernel_size=kernel_size)

    def call(self, x):
        residual = x
        out = x * self.channel_attention(x)
        out = out * self.spatial_attention(out)
        return out + residual
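
As a sanity check on the block above, a hypothetical smoke test (module path taken from this diff; channels_last inputs assumed):

```python
# Hypothetical smoke test; channels_last inputs assumed.
import numpy as np
from k3_addons.layers.attention.cbam import CBAMBlock

x = np.random.random((1, 16, 16, 32)).astype("float32")
out = CBAMBlock(reduction=16, kernel_size=7)(x)
assert out.shape == x.shape  # channel and spatial gates preserve the shape
```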
23 changes: 23 additions & 0 deletions k3_addons/layers/attention/eca.py
@@ -0,0 +1,23 @@
from keras import layers, ops

from k3_addons.layers.pooling.adaptive_pooling import AdaptiveAveragePool2D
from k3_addons.api_export import k3_export


@k3_export("k3_addons.layers.ECAAttention")
class ECAAttention(layers.Layer):
    """Efficient Channel Attention: a 1D convolution across the channel axis."""

    def __init__(self, kernel_size=3):
        super().__init__()
        self.pooling = AdaptiveAveragePool2D(1)
        self.conv = layers.Conv1D(1, kernel_size=kernel_size, padding="same")

    def call(self, x):
        # Shape comments assume channels_last inputs of shape (b, h, w, c).
        y = self.pooling(x)                   # (b, 1, 1, c)
        y = ops.squeeze(y, axis=2)            # (b, 1, c)
        y = ops.transpose(y, axes=[0, 2, 1])  # (b, c, 1)
        y = self.conv(y)                      # (b, c, 1)
        y = ops.sigmoid(y)                    # (b, c, 1)
        y = ops.transpose(y, axes=[0, 2, 1])  # (b, 1, c)
        y = ops.expand_dims(y, axis=2)        # (b, 1, 1, c)
        return x * ops.broadcast_to(y, ops.shape(x))
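
Likewise, a hypothetical shape check for `ECAAttention` under the same assumptions:

```python
# Hypothetical shape check; channels_last inputs assumed.
import numpy as np
from k3_addons.layers.attention.eca import ECAAttention

x = np.random.random((1, 8, 8, 64)).astype("float32")
out = ECAAttention(kernel_size=3)(x)
assert out.shape == x.shape  # per-channel gates broadcast over h and w
```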
