Skip to content

Commit

Permalink
add losses api and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-rz committed Feb 27, 2024
1 parent a1937d2 commit f4365dc
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 1 deletion.
24 changes: 24 additions & 0 deletions k3_addons/losses/contrastive_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from keras import ops
from k3_addons.utils.keras_utils import LossFunctionWrapper
from k3_addons.api_export import k3_export


def contrastive_loss(y_true, y_pred, margin=1.0):
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.cast(y_true, y_pred.dtype)
return y_true * ops.square(y_pred) + (1.0 - y_true) * ops.square(
ops.maximum(margin - y_pred, 0.0)
)


@k3_export("k3_addons.losses.ContrastiveLoss")
class ContrastiveLoss(LossFunctionWrapper):
def __init__(
self,
margin=1.0,
reduction=None,
name="contrastive_loss",
):
super().__init__(
contrastive_loss, reduction=reduction, name=name, margin=margin
)
13 changes: 13 additions & 0 deletions k3_addons/losses/contrastive_loss_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import numpy as np
from keras import ops
from k3_addons.losses.contrastive_loss import ContrastiveLoss


def test_contrastive_loss():
a = ops.convert_to_tensor([2, 3, 5], dtype="float16")
b = ops.convert_to_tensor([5, 3, 1], dtype="float16")
loss = ContrastiveLoss()(a, b)
assert ops.shape(loss) == (3,)
assert np.allclose(
loss, ops.convert_to_tensor([50.0, 27.0, 5.0], dtype="float16")
) # from tf_addons output
62 changes: 62 additions & 0 deletions k3_addons/losses/focal_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from keras import ops
from k3_addons.utils.keras_utils import LossFunctionWrapper

from k3_addons.api_export import k3_export


def sigmoid_focal_crossentropy(
y_true,
y_pred,
alpha=0.25,
gamma=2.0,
from_logits: bool = False,
):
if gamma and gamma < 0:
raise ValueError("Value of gamma should be greater than or equal to zero.")

y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.cast(y_true, dtype=y_pred.dtype)

# Get the cross_entropy for each entry
ce = ops.binary_crossentropy(y_true, y_pred, from_logits=from_logits)

# If logits are provided then convert the predictions into probabilities
if from_logits:
pred_prob = ops.sigmoid(y_pred)
else:
pred_prob = y_pred

p_t = (y_true * pred_prob) + ((1 - y_true) * (1 - pred_prob))
alpha_factor = 1.0
modulating_factor = 1.0

if alpha:
alpha = ops.cast(alpha, dtype=y_true.dtype)
alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)

if gamma:
gamma = ops.cast(gamma, dtype=y_true.dtype)
modulating_factor = ops.power((1.0 - p_t), gamma)

# compute the final loss and return
return ops.sum(alpha_factor * modulating_factor * ce, axis=-1)


@k3_export("k3_addons.losses.SigmoidFocalCrossEntropy")
class SigmoidFocalCrossEntropy(LossFunctionWrapper):
def __init__(
self,
from_logits: bool = False,
alpha=0.25,
gamma=2.0,
reduction=None,
name="sigmoid_focal_crossentropy",
):
super().__init__(
sigmoid_focal_crossentropy,
name=name,
reduction=reduction,
from_logits=from_logits,
alpha=alpha,
gamma=gamma,
)
22 changes: 22 additions & 0 deletions k3_addons/losses/focal_loss_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
import numpy as np
from keras import ops

from k3_addons.losses.focal_loss import SigmoidFocalCrossEntropy


@pytest.mark.parametrize(
"y_true, y_pred",
[([[1.0], [1.0], [0.0]], [[0.97], [0.91], [0.03]])],
)
def test_sigmoid_focal_crossentropy(y_true, y_pred):
out_tf = ops.convert_to_tensor(
[6.8532745e-06, 1.9097870e-04, 2.0559824e-05]
) # from tensorflow_addons
# Calculate sigmoid within the test
y_pred_sigmoid = y_pred

# Use your focal loss implementation with the calculated sigmoid
loss = SigmoidFocalCrossEntropy()(y_true=y_true, y_pred=y_pred_sigmoid)

assert np.allclose(loss, out_tf)
60 changes: 60 additions & 0 deletions k3_addons/losses/giou_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from keras import ops, backend
from k3_addons.utils.keras_utils import LossFunctionWrapper
from k3_addons.api_export import k3_export


@k3_export("k3_addons.losses.GIoULoss")
class GIoULoss(LossFunctionWrapper):
def __init__(
self,
mode="giou",
reduction=None,
name="giou_loss",
):
super().__init__(giou_loss, name=name, reduction=reduction, mode=mode)


def giou_loss(y_true, y_pred, mode="giou"):
if mode not in ["giou", "iou"]:
raise ValueError("Value of mode should be 'iou' or 'giou'")
y_pred = ops.convert_to_tensor(y_pred)
if not backend.is_float_dtype(y_pred.dtype):
y_pred = ops.cast(y_pred, "float32")
y_true = ops.cast(y_true, y_pred.dtype)
giou = ops.squeeze(_calculate_giou(y_pred, y_true, mode))
return 1 - giou


def _calculate_giou(b1, b2, mode="giou"):
zero = ops.convert_to_tensor(0.0, b1.dtype)
b1_ymin, b1_xmin, b1_ymax, b1_xmax = ops.unstack(b1, 4, axis=-1)
b2_ymin, b2_xmin, b2_ymax, b2_xmax = ops.unstack(b2, 4, axis=-1)
b1_width = ops.maximum(zero, b1_xmax - b1_xmin)
b1_height = ops.maximum(zero, b1_ymax - b1_ymin)
b2_width = ops.maximum(zero, b2_xmax - b2_xmin)
b2_height = ops.maximum(zero, b2_ymax - b2_ymin)
b1_area = b1_width * b1_height
b2_area = b2_width * b2_height

intersect_ymin = ops.maximum(b1_ymin, b2_ymin)
intersect_xmin = ops.maximum(b1_xmin, b2_xmin)
intersect_ymax = ops.minimum(b1_ymax, b2_ymax)
intersect_xmax = ops.minimum(b1_xmax, b2_xmax)
intersect_width = ops.maximum(zero, intersect_xmax - intersect_xmin)
intersect_height = ops.maximum(zero, intersect_ymax - intersect_ymin)
intersect_area = intersect_width * intersect_height

union_area = b1_area + b2_area - intersect_area
iou = ops.divide_no_nan(intersect_area, union_area)
if mode == "iou":
return iou

enclose_ymin = ops.minimum(b1_ymin, b2_ymin)
enclose_xmin = ops.minimum(b1_xmin, b2_xmin)
enclose_ymax = ops.maximum(b1_ymax, b2_ymax)
enclose_xmax = ops.maximum(b1_xmax, b2_xmax)
enclose_width = ops.maximum(zero, enclose_xmax - enclose_xmin)
enclose_height = ops.maximum(zero, enclose_ymax - enclose_ymin)
enclose_area = enclose_width * enclose_height
giou = iou - ops.divide_no_nan((enclose_area - union_area), enclose_area)
return giou
16 changes: 16 additions & 0 deletions k3_addons/losses/giou_loss_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np
from keras import ops

from k3_addons.losses.giou_loss import GIoULoss


def test_sigmoid_giou():
boxes1 = ops.convert_to_tensor([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
boxes2 = ops.convert_to_tensor([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])

out_tfa = ops.convert_to_tensor([1.075, 1.9333334])

# Use your focal loss implementation with the calculated sigmoid
loss = GIoULoss()(boxes1, boxes2)

assert np.allclose(loss, out_tfa)
11 changes: 11 additions & 0 deletions k3_addons/utils/keras_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from keras.losses import Loss


class LossFunctionWrapper(Loss):
def __init__(self, fn, reduction=None, name=None, **kwargs):
super().__init__(reduction=reduction, name=name)
self.fn = fn
self._fn_kwargs = kwargs

def call(self, y_true, y_pred):
return self.fn(y_true, y_pred, **self._fn_kwargs)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_version(rel_path):
author_email="memanasraza@gmail.com",
license="Apache License 2.0",
install_requires=[
"keras>=3.0"
"keras>=3.0.5"
],
# Supported Python versions
python_requires=">=3.9",
Expand Down

0 comments on commit f4365dc

Please sign in to comment.