Skip to content

Commit

Permalink
add F1 metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-rz committed Feb 28, 2024
1 parent 3d961cb commit 0cb0f41
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
20 changes: 19 additions & 1 deletion k3_addons/metrics/f_scores.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from keras import ops, metrics
from k3_addons.api_export import k3_export


@k3_export(path="k3_addons.metrics.FBetaScore")
class FBetaScore(metrics.Metric):
def __init__(
self,
Expand Down Expand Up @@ -123,3 +124,20 @@ def reset_state(self):
reset_value = ops.zeros(self.init_shape, dtype=self.dtype)
for v in self.variables:
v.assign(reset_value)

@k3_export(path="k3_addons.metrics.F1Score")
class F1Score(FBetaScore):
def __init__(
self,
num_classes,
average=None,
threshold=None,
name="f1_score",
dtype=None,
):
super().__init__(num_classes, average, 1.0, threshold, name=name, dtype=dtype)

def get_config(self):
base_config = super().get_config()
del base_config["beta"]
return base_config
54 changes: 53 additions & 1 deletion k3_addons/metrics/f_scores_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from k3_addons.metrics.f_scores import FBetaScore
from k3_addons.metrics.f_scores import FBetaScore, F1Score
from k3_addons.metrics.utils import _get_model
from keras import ops, backend

Expand Down Expand Up @@ -167,3 +167,55 @@ def test_fbeta_weighted_random_score_none(avg_val, beta, sample_weights, result)
def test_keras_model():
fbeta = FBetaScore(5, "micro", 1.0)
_get_model(fbeta, 5)


def test_eq():
f1 = F1Score(3)
fbeta = FBetaScore(3, beta=1.0)

preds = [
[0.9, 0.1, 0],
[0.2, 0.6, 0.2],
[0, 0, 1],
[0.4, 0.3, 0.3],
[0, 0.9, 0.1],
[0, 0, 1],
]
actuals = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]]
preds = ops.convert_to_tensor(preds, "float32")
actuals = ops.convert_to_tensor(actuals, "float32")
fbeta.update_state(actuals, preds)
f1.update_state(actuals, preds)
np.testing.assert_allclose(fbeta.result().numpy(), f1.result().numpy())


def test_sample_eq():
f1 = F1Score(3)
f1_weighted = F1Score(3)

preds = ops.convert_to_tensor([
[0.9, 0.1, 0],
[0.2, 0.6, 0.2],
[0, 0, 1],
[0.4, 0.3, 0.3],
[0, 0.9, 0.1],
[0, 0, 1],
])
actuals = ops.convert_to_tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]])
sample_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

f1.update_state(actuals, preds)
f1_weighted(actuals, preds, sample_weights)
np.testing.assert_allclose(f1.result().numpy(), f1_weighted.result().numpy())


def test_keras_model_f1():
f1 = F1Score(5)
_get_model(f1, 5)


def test_config_f1():
f1 = F1Score(3)
config = f1.get_config()
assert "beta" not in config

0 comments on commit 0cb0f41

Please sign in to comment.