Skip to content

Commit

Permalink
more layers..
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-rz committed Feb 27, 2024
1 parent bf589de commit a1937d2
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
1 change: 1 addition & 0 deletions k3_addons/layers/attention/simam.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from keras import layers, ops
from k3_addons.api_export import k3_export


@k3_export(path="k3_addons.layers.SimAM")
class SimAM(layers.Layer):
def __init__(self, e_lambda=1e-4, activation="sigmoid"):
Expand Down
38 changes: 38 additions & 0 deletions k3_addons/layers/pooling/maxout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from keras import layers, ops
from k3_addons.api_export import k3_export


@k3_export("k3_addons.layers.Maxout")
class Maxout(layers.Layer):
def __init__(self, num_units: int, axis: int = -1, **kwargs):
super().__init__(**kwargs)
self.num_units = num_units
self.axis = axis

def call(self, inputs):
shape = list(ops.shape(inputs))
# Dealing with batches with arbitrary sizes
for i in range(len(shape)):
if shape[i] is None:
shape[i] = ops.shape(inputs)[i]

num_channels = shape[self.axis]
if num_channels % self.num_units:
raise ValueError(
"number of features({}) is not "
"a multiple of num_units({})".format(num_channels, self.num_units)
)

if self.axis < 0:
axis = self.axis + len(shape)
else:
axis = self.axis
assert axis >= 0, "Find invalid axis: {}".format(self.axis)

expand_shape = shape[:]
expand_shape[axis] = self.num_units
k = num_channels // self.num_units
expand_shape.insert(axis, k)

outputs = ops.max(ops.reshape(inputs, expand_shape), axis, keepdims=False)
return outputs
25 changes: 25 additions & 0 deletions k3_addons/layers/pooling/maxout_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
import keras
from keras import ops
from k3_addons.layers.pooling.maxout import Maxout


@pytest.mark.parametrize(
"num_units, axis",
[
(8, -1),
(4, -1),
(16, -1),
(8, -2),
],
)
def test_maxout_output_shape(num_units, axis):
input_shape = (1, 224, 224, 32)
inputs = keras.random.uniform(input_shape)
out = Maxout(num_units, axis=axis)(inputs)

# Construct the expected output shape
expected_shape = list(input_shape)
expected_shape[axis] = num_units

assert ops.shape(out) == tuple(expected_shape)

0 comments on commit a1937d2

Please sign in to comment.