Skip to content

Commit

Permalink
lint + quantiles
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-rz committed Feb 27, 2024
1 parent f484d13 commit e448b64
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 9 deletions.
11 changes: 5 additions & 6 deletions k3_addons/losses/kappa_loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from keras.losses import Loss
from keras import ops, backend
from k3_addons.api_export import k3_export
Expand All @@ -9,10 +8,10 @@ class WeightedKappaLoss(Loss):
def __init__(
self,
num_classes,
weightage = "quadratic",
name = "cohen_kappa_loss",
epsilon = 1e-6,
reduction = None,
weightage="quadratic",
name="cohen_kappa_loss",
epsilon=1e-6,
reduction=None,
):
super().__init__(name=name, reduction=reduction)

Expand Down Expand Up @@ -59,4 +58,4 @@ def get_config(self):
"epsilon": self.epsilon,
}
base_config = super().get_config()
return {**base_config, **config}
return {**base_config, **config}
17 changes: 14 additions & 3 deletions k3_addons/losses/kappa_loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,20 @@
from k3_addons.losses.kappa_loss import WeightedKappaLoss

import numpy as np


def test_kappa_loss():
y_true = ops.convert_to_tensor([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]])
y_pred = ops.convert_to_tensor([[0.1, 0.2, 0.6, 0.1], [0.1, 0.5, 0.3, 0.1],[0.8, 0.05, 0.05, 0.1], [0.01, 0.09, 0.1, 0.8]])
y_true = ops.convert_to_tensor(
[[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]]
)
y_pred = ops.convert_to_tensor(
[
[0.1, 0.2, 0.6, 0.1],
[0.1, 0.5, 0.3, 0.1],
[0.8, 0.05, 0.05, 0.1],
[0.01, 0.09, 0.1, 0.8],
]
)
kappa_loss = WeightedKappaLoss(num_classes=4)
loss = kappa_loss(y_true, y_pred)
np.allclose(ops.convert_to_numpy(loss), -1.1611925)
np.allclose(ops.convert_to_numpy(loss), -1.1611925)
27 changes: 27 additions & 0 deletions k3_addons/losses/quantiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from keras import ops
from k3_addons.utils.keras_utils import LossFunctionWrapper
from k3_addons.api_export import k3_export


@k3_export("k3_addons.losses.pinball_loss")
def pinball_loss(y_true, y_pred, tau=0.5):
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.cast(y_true, y_pred.dtype)

tau = ops.expand_dims(ops.cast(tau, y_pred.dtype), 0)
one = ops.cast(1, tau.dtype)

delta_y = y_true - y_pred
pinball = ops.maximum(tau * delta_y, (tau - one) * delta_y)
return ops.mean(pinball, axis=-1)


@k3_export("k3_addons.losses.PinballLoss")
class PinballLoss(LossFunctionWrapper):
def __init__(
self,
tau=0.5,
reduction=None,
name="pinball_loss",
):
super().__init__(pinball_loss, reduction=reduction, name=name, tau=tau)
8 changes: 8 additions & 0 deletions k3_addons/losses/quantiles_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import numpy as np
from keras import ops
from k3_addons.losses.quantiles import PinballLoss


def test_pinball_loss():
loss = PinballLoss(tau=.1)([0., 0., 1., 1.],[1., 1., 1., 0.])
assert np.allclose(ops.convert_to_numpy(loss), 0.475)ß

0 comments on commit e448b64

Please sign in to comment.