diff --git a/README.md b/README.md index b63a14a..52b8e55 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,21 @@ # k3-addons: Additional multi-backend functionality for Keras 3. ![Logo](.assets/k-addons.png) +K3 Addons supercharge your multibackend Keras 3 workflow, giving access to various innovative machine learning techniques. While Keras 3 offers a rich set of APIs, not everything can be included in the core APIs due to less generic usage. K3 Addons bridges this gap, ensuring you're not limited by the core Keras 3 library. These add-ons might include various attention mechanisms for Text and Image Data, advanced optimizers, or specialized layers tailored for unique data types. With K3 Addons, you'll gain the flexibility to tackle emerging ML challenges and push the boundaries of what's possible with Keras 3. + # Installation +To Install K3 Addons simply run following command in your environment: ```bash pip install k3-addons ``` # Includes: + +Currently includes `layers`, `losses`, and `activations` API. + - ## Layers + - ### Pooling: - `k3_addons.layers.AdaptiveAveragePooling1D` diff --git a/k3_addons/layers/attention/a2a.py b/k3_addons/layers/attention/a2a.py index 20c759b..aff99b9 100644 --- a/k3_addons/layers/attention/a2a.py +++ b/k3_addons/layers/attention/a2a.py @@ -7,8 +7,8 @@ class DoubleAttention(layers.Layer): """A2-Nets: Double Attention Networks [https://arxiv.org/pdf/1810.11579.pdf]""" - def __init__(self, dim, value_dim=None, reconstruct=True): - super().__init__() + def __init__(self, dim, value_dim=None, reconstruct=True, **kwargs): + super().__init__(**kwargs) self.dim = dim self.value_dim = value_dim or dim self.reconstruct = reconstruct @@ -38,3 +38,14 @@ def call(self, x): if self.reconstruct: out = self.conv_reconstruct(out) # b,h,w,c return out + + def get_config(self): + config = super().get_config() + config.update( + { + "dim": self.dim, + "value_dim": self.value_dim, + "reconstruct": self.reconstruct, + } + ) + return config diff --git a/k3_addons/layers/attention/aft.py b/k3_addons/layers/attention/aft.py index a19433c..32a82cc 100644 --- a/k3_addons/layers/attention/aft.py +++ b/k3_addons/layers/attention/aft.py @@ -6,8 +6,8 @@ class AFTFull(layers.Layer): """An Attention Free Transformer [https://arxiv.org/pdf/2105.14103v1.pdf]""" - def __init__(self, projection_dim, position_bias=False): - super(AFTFull, self).__init__() + def __init__(self, projection_dim, position_bias=False, **kwargs): + super().__init__(**kwargs) self.position_bias = position_bias self.projection_dim = projection_dim self.to_q = layers.Dense(projection_dim) @@ -36,3 +36,13 @@ def call(self, inputs): out = numerator / denominator # n,bs,dim out = ops.sigmoid(q) * (ops.transpose(out, (1, 0, 2))) # bs,n,dim return out + + def get_config(self): + config = super().get_config() + config.update( + { + "projection_dim": self.projection_dim, + "position_bias": self.position_bias, + } + ) + return config diff --git a/k3_addons/layers/attention/bam.py b/k3_addons/layers/attention/bam.py index 0d1bceb..00a71a3 100644 --- a/k3_addons/layers/attention/bam.py +++ b/k3_addons/layers/attention/bam.py @@ -4,8 +4,8 @@ class ChannelAttention(layers.Layer): - def __init__(self, reduction=16, num_layers=3): - super().__init__() + def __init__(self, reduction=16, num_layers=3, **kwargs): + super().__init__(**kwargs) self.avgpool = AdaptiveAveragePool2D(1) self.reduction = reduction self.num_layers = num_layers @@ -36,10 +36,20 @@ def call(self, x): res = ops.broadcast_to(res, ops.shape(x)) return res + def get_config(self): + config = super().get_config() + config.update( + { + "reduction": self.reduction, + "num_layers": self.num_layers, + } + ) + return config + class SpatialAttention(layers.Layer): - def __init__(self, reduction=16, num_layers=3, dilation_rate=2): - super().__init__() + def __init__(self, reduction=16, num_layers=3, dilation_rate=2, **kwargs): + super().__init__(**kwargs) self.reduction = reduction self.num_layers = num_layers @@ -74,6 +84,17 @@ def call(self, x): res = ops.broadcast_to(res, ops.shape(x)) return res + def get_config(self): + config = super().get_config() + config.update( + { + "reduction": self.reduction, + "num_layers": self.num_layers, + "dilation_rate": self.dilation_rate, + } + ) + return config + @k3_export(path="k3_addons.layers.BAMBlock") class BAMBlock(layers.Layer): @@ -82,8 +103,8 @@ class BAMBlock(layers.Layer): """ - def __init__(self, reduction=16, dilation_rate=2): - super().__init__() + def __init__(self, reduction=16, dilation_rate=2, **kwargs): + super().__init__(**kwargs) self.channel_attention = ChannelAttention(reduction=reduction) self.spatial_attention = SpatialAttention( reduction=reduction, dilation_rate=dilation_rate diff --git a/k3_addons/layers/attention/cbam.py b/k3_addons/layers/attention/cbam.py index c3d33a1..8a520d6 100644 --- a/k3_addons/layers/attention/cbam.py +++ b/k3_addons/layers/attention/cbam.py @@ -9,8 +9,8 @@ @k3_export("k3_addons.layers.ChannelAttention2D") class ChannelAttention(layers.Layer): - def __init__(self, reduction=16): - super().__init__() + def __init__(self, reduction=16, **kwargs): + super().__init__(**kwargs) self.reduction = reduction self.maxpool = AdaptiveMaxPool2D(1) self.avgpool = AdaptiveAveragePool2D(1) @@ -37,11 +37,20 @@ def call(self, x): output = ops.sigmoid(max_out + avg_out) return output + def get_config(self): + config = super().get_config() + config.update( + { + "reduction": self.reduction, + } + ) + return config + @k3_export("k3_addons.layers.SpatialAttention2D") class SpatialAttention(layers.Layer): - def __init__(self, kernel_size=7): - super().__init__() + def __init__(self, kernel_size=7, **kwargs): + super().__init__(**kwargs) self.conv = layers.Conv2D(1, kernel_size=kernel_size, padding="same") def call(self, x): @@ -63,8 +72,8 @@ class CBAMBlock(layers.Layer): CBAM: Convolutional Block Attention Module [https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf] """ - def __init__(self, reduction=16, kernel_size=49): - super().__init__() + def __init__(self, reduction=16, kernel_size=49, **kwargs): + super().__init__(**kwargs) self.channel_attention = ChannelAttention(reduction=reduction) self.spatial_attention = SpatialAttention(kernel_size=kernel_size) diff --git a/k3_addons/layers/attention/eca.py b/k3_addons/layers/attention/eca.py index d8aecd8..be2a87e 100644 --- a/k3_addons/layers/attention/eca.py +++ b/k3_addons/layers/attention/eca.py @@ -9,8 +9,8 @@ class ECAAttention(layers.Layer): ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks [https://arxiv.org/pdf/1910.03151.pdf] """ - def __init__(self, kernel_size=3): - super().__init__() + def __init__(self, kernel_size=3, **kwargs): + super().__init__(**kwargs) self.pooling = AdaptiveAveragePool2D(1) self.conv = layers.Conv1D(1, kernel_size=kernel_size, padding="same") diff --git a/k3_addons/layers/attention/external_attention.py b/k3_addons/layers/attention/external_attention.py index 68b964b..e9ad2b1 100644 --- a/k3_addons/layers/attention/external_attention.py +++ b/k3_addons/layers/attention/external_attention.py @@ -8,8 +8,8 @@ class ExternalAttention(layers.Layer): Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks [https://arxiv.org/abs/2105.02358] """ - def __init__(self, intermediate_dim=64): - super().__init__() + def __init__(self, intermediate_dim=64, **kwargs): + super().__init__(**kwargs) self.intermediate_dim = intermediate_dim def build(self, input_shape): @@ -24,3 +24,12 @@ def call(self, queries): attn = attn / ops.sum(attn, axis=2, keepdims=True) out = self.mv(attn) return out + + def get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + } + ) + return config diff --git a/k3_addons/layers/attention/mobilevit.py b/k3_addons/layers/attention/mobilevit.py index abc9b65..4c782fe 100644 --- a/k3_addons/layers/attention/mobilevit.py +++ b/k3_addons/layers/attention/mobilevit.py @@ -91,8 +91,9 @@ def __init__( heads=8, head_dim=64, mlp_dim=1024, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.ph, self.pw = patch_size, patch_size self.dim = dim self.kernel_size = kernel_size @@ -140,3 +141,20 @@ def call(self, x): x = self.conv4(x) return x + + def get_config(self): + config = super().get_config() + config.update( + { + "ph": self.ph, + "pw": self.pw, + "dim": self.dim, + "kernel_size": self.kernel_size, + "patch_size": self.patch_size, + "depth": self.depth, + "heads": self.heads, + "head_dim": self.head_dim, + "mlp_dim": self.mlp_dim, + } + ) + return config diff --git a/k3_addons/layers/attention/mobilevit_v2.py b/k3_addons/layers/attention/mobilevit_v2.py index 525d6b6..e52b73a 100644 --- a/k3_addons/layers/attention/mobilevit_v2.py +++ b/k3_addons/layers/attention/mobilevit_v2.py @@ -4,8 +4,8 @@ @k3_export(path="k3_addons.layers.MobileViTv2Attention") class MobileViTv2Attention(layers.Layer): - def __init__(self): - super().__init__() + def __init__(self, **kwargs): + super().__init__(**kwargs) def build(self, input_shape): projection_dim = input_shape[-1] diff --git a/k3_addons/layers/attention/parnet.py b/k3_addons/layers/attention/parnet.py index 68e40d8..0b782e6 100644 --- a/k3_addons/layers/attention/parnet.py +++ b/k3_addons/layers/attention/parnet.py @@ -5,8 +5,8 @@ @k3_export(path="k3_addons.layers.ParNetAttention") class ParNetAttention(layers.Layer): - def __init__(self, activation="selu"): - super().__init__() + def __init__(self, activation="selu", **kwargs): + super().__init__(**kwargs) self.activation = activation def build(self, input_shape): @@ -35,3 +35,12 @@ def call(self, x): x3 = self.sse(x) * x out = self.activation(x1 + x2 + x3) return out + + def get_config(self): + config = super().get_config() + config.update( + { + "activation": self.activation, + } + ) + return config diff --git a/k3_addons/layers/attention/residual.py b/k3_addons/layers/attention/residual.py index 8ccbaf9..d088f66 100644 --- a/k3_addons/layers/attention/residual.py +++ b/k3_addons/layers/attention/residual.py @@ -6,8 +6,8 @@ class ResidualAttention(layers.Layer): """Residual Attention: A Simple but Effective Method for Multi-Label Recognition [https://arxiv.org/abs/2108.02456]""" - def __init__(self, num_class=1000, alpha=0.2): - super().__init__() + def __init__(self, num_class=1000, alpha=0.2, **kwargs): + super().__init__(**kwargs) self.alpha = alpha self.num_class = num_class self.fc = layers.Conv2D( @@ -24,3 +24,13 @@ def call(self, x): out = x_avg + self.alpha * x_max return out + + def get_config(self): + config = super().get_config() + config.update( + { + "num_class": self.num_class, + "alpha": self.alpha, + } + ) + return config diff --git a/k3_addons/layers/attention/se.py b/k3_addons/layers/attention/se.py index 05d9846..261a822 100644 --- a/k3_addons/layers/attention/se.py +++ b/k3_addons/layers/attention/se.py @@ -3,8 +3,8 @@ class SEAttention(layers.Layer): - def __init__(self, reduction=16): - super().__init__() + def __init__(self, reduction=16, **kwargs): + super().__init__(**kwargs) self.reduction = reduction def build(self, input_shape): @@ -28,3 +28,12 @@ def call(self, x): x = ops.expand_dims(x, axis=1) x = ops.expand_dims(x, axis=2) return x_skip * ops.broadcast_to(x, ops.shape(x_skip)) + + def get_config(self): + config = super().get_config() + config.update( + { + "reduction": self.reduction, + } + ) + return config diff --git a/k3_addons/layers/attention/simam.py b/k3_addons/layers/attention/simam.py index ba84f06..0ec2d6e 100644 --- a/k3_addons/layers/attention/simam.py +++ b/k3_addons/layers/attention/simam.py @@ -4,10 +4,10 @@ @k3_export(path="k3_addons.layers.SimAM") class SimAM(layers.Layer): - def __init__(self, e_lambda=1e-4, activation="sigmoid"): - super().__init__() + def __init__(self, e_lambda=1e-4, activation="sigmoid", **kwargs): + super().__init__(**kwargs) self.e_lambda = e_lambda - self.activaton = layers.Activation(activation) + self.activation = layers.Activation(activation) def call(self, x): b, h, w, c = ops.shape(x) @@ -18,4 +18,13 @@ def call(self, x): denom = ops.sum(x_minus_mu_square, axis=1, keepdims=True) denom = ops.sum(denom, axis=2, keepdims=True) / n weights = x_minus_mu_square / (4 * (denom + self.e_lambda)) + 0.5 - return x * self.activaton(weights) + return x * self.activation(weights) + + def get_config(self): + config = super().get_config() + config.update( + { + "e_lambda": self.e_lambda, + } + ) + return config diff --git a/k3_addons/layers/pooling/adaptive_pooling.py b/k3_addons/layers/pooling/adaptive_pooling.py index 94701cb..149f246 100644 --- a/k3_addons/layers/pooling/adaptive_pooling.py +++ b/k3_addons/layers/pooling/adaptive_pooling.py @@ -92,6 +92,19 @@ def compute_output_shape(self, input_shape): self.data_format, ) + def get_config(self): + config = super().get_config() + config.update( + { + "output_size": self.output_size, + "pool_dimensions": self.pool_dimensions, + "data_format": self.data_format, + "padding": self.padding, + "pool_mode": self.pool_mode, + } + ) + return config + @k3_export(path="k3_addons.layers.AdaptiveMaxPool1D") class AdaptiveMaxPool1D(BaseAdaptivePool): diff --git a/k3_addons/layers/pooling/maxout.py b/k3_addons/layers/pooling/maxout.py index 6a2c7ec..c794d0d 100644 --- a/k3_addons/layers/pooling/maxout.py +++ b/k3_addons/layers/pooling/maxout.py @@ -36,3 +36,13 @@ def call(self, inputs): outputs = ops.max(ops.reshape(inputs, expand_shape), axis, keepdims=False) return outputs + + def get_config(self): + config = super().get_config() + config.update( + { + "num_units": self.num_units, + "axis": self.axis, + } + ) + return config diff --git a/pip_build.py b/pip_build.py index 0687cfb..859f648 100644 --- a/pip_build.py +++ b/pip_build.py @@ -1,20 +1,3 @@ -"""Script to create (and optionally install) a `.whl` archive for Keras 3. - -Usage: - -1. Create a `.whl` file in `dist/`: - -``` -python3 pip_build.py -``` - -2. Also install the new package immediately after: - -``` -python3 pip_build.py --install -``` -""" - import argparse import datetime import glob @@ -24,8 +7,6 @@ import namex -# Needed because importing torch after TF causes the runtime to crash - package = "k3_addons" build_directory = "tmp_build_dir" dist_directory = "dist"