Commit

add test for cbam
anas-rz committed Feb 24, 2024
1 parent f1a7197 commit d0338ea
Showing 2 changed files with 32 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -1,6 +1,10 @@
# k3-addons: Additional multi-backend functionality for Keras 3.
![Logo](.assets/k-addons.png)

# Installation

`pip install k3-addons`

# Includes:
- Layers
- Pooling:
Expand Down
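For context, a minimal usage sketch of the package installed above (not part of this commit). The import path, default constructor, and channels-last input layout are taken from the test file added below; the rest is illustrative:

```python
import keras

from k3_addons.layers.attention.cbam import CBAMBlock

# channels-last feature map, matching the shapes used in the new tests
features = keras.random.normal((1, 14, 14, 128))

# CBAM refines the feature map; the output shape matches the input shape
refined = CBAMBlock()(features)
```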
28 changes: 28 additions & 0 deletions k3_addons/layers/attention/cbam_test.py
@@ -0,0 +1,28 @@
import pytest
import keras
from keras import ops


from k3_addons.layers.attention.cbam import ChannelAttention, SpatialAttention, CBAMBlock


@pytest.mark.parametrize("input_shape", [(1, 10, 10, 256), (1, 14, 14, 128)])
def test_channel_attention(input_shape):
    inputs = keras.random.normal(input_shape)
    layer = ChannelAttention()
    out = layer(inputs)
    # channel attention collapses the spatial dims to one weight per channel
    assert ops.shape(out) == (1, 1, 1, input_shape[-1])

@pytest.mark.parametrize("input_shape", [(1, 10, 10, 256), (1, 14, 14, 128)])
def test_spatial_attention(input_shape):
    inputs = keras.random.normal(input_shape)
    layer = SpatialAttention()
    out = layer(inputs)
    # spatial attention keeps the spatial dims and collapses channels to one map
    assert ops.shape(out) == input_shape[:-1] + (1,)

@pytest.mark.parametrize("input_shape", [(1, 10, 10, 256), (1, 14, 14, 128)])
def test_cbam(input_shape):
    inputs = keras.random.normal(input_shape)
    layer = CBAMBlock()
    out = layer(inputs)
    # the full CBAM block refines features without changing their shape
    assert ops.shape(out) == input_shape
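With the repository checked out, the new parametrized cases (three layers x two input shapes, six tests in total) can be run with, e.g., `pytest k3_addons/layers/attention/cbam_test.py`.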
