Skip to content

Commit

Permalink
Merge pull request #75 from ayrna/to_array_conversion_in_metrics
Browse files Browse the repository at this point in the history
[ENH] Added to array and probas to preds castings in metrics
  • Loading branch information
RafaAyGar authored Jul 17, 2024
2 parents d7b3faf + b9c256a commit 107e49b
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions dlordinal/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def ranked_probability_score(y_true, y_proba):
>>> ranked_probability_score(y_true, y_pred)
0.506875
"""
y_true = np.array(y_true)
y_proba = np.array(y_proba)

y_oh = np.zeros(y_proba.shape)
y_oh[np.arange(len(y_true)), y_true] = 1

Expand Down Expand Up @@ -72,6 +75,13 @@ def minimum_sensitivity(y_true: np.ndarray, y_pred: np.ndarray) -> float:
>>> minimum_sensitivity(y_true, y_pred)
0.5
"""
y_true = np.array(y_true)
y_pred = np.array(y_pred)

if len(y_true.shape) > 1:
y_true = np.argmax(y_true, axis=1)
if len(y_pred.shape) > 1:
y_pred = np.argmax(y_pred, axis=1)

sensitivities = recall_score(y_true, y_pred, average=None)
return np.min(sensitivities)
Expand Down Expand Up @@ -101,6 +111,8 @@ def accuracy_off1(y_true: np.ndarray, y_pred: np.ndarray, labels=None) -> float:
>>> accuracy_off1(y_true, y_pred)
1.0
"""
y_true = np.array(y_true)
y_pred = np.array(y_pred)

if len(y_true.shape) > 1:
y_true = np.argmax(y_true, axis=1)
Expand Down Expand Up @@ -141,6 +153,13 @@ def gmsec(y_true: np.ndarray, y_pred: np.ndarray) -> float:
>>> gmec(y_true, y_pred)
0.5
"""
y_true = np.array(y_true)
y_pred = np.array(y_pred)

if len(y_true.shape) > 1:
y_true = np.argmax(y_true, axis=1)
if len(y_pred.shape) > 1:
y_pred = np.argmax(y_pred, axis=1)

sensitivities = recall_score(y_true, y_pred, average=None)
return np.sqrt(sensitivities[0] * sensitivities[-1])
Expand All @@ -161,6 +180,8 @@ def amae(y_true: np.ndarray, y_pred: np.ndarray):
amae : float
Average mean absolute error.
"""
y_true = np.array(y_true)
y_pred = np.array(y_pred)

if len(y_true.shape) > 1:
y_true = np.argmax(y_true, axis=1)
Expand Down Expand Up @@ -193,6 +214,8 @@ def mmae(y_true: np.ndarray, y_pred: np.ndarray):
mmae : float
Maximum mean absolute error.
"""
y_true = np.array(y_true)
y_pred = np.array(y_pred)

if len(y_true.shape) > 1:
y_true = np.argmax(y_true, axis=1)
Expand Down

0 comments on commit 107e49b

Please sign in to comment.