diff --git a/dlordinal/metrics/metrics.py b/dlordinal/metrics/metrics.py index b2ca3da..33a35f3 100644 --- a/dlordinal/metrics/metrics.py +++ b/dlordinal/metrics/metrics.py @@ -35,6 +35,9 @@ def ranked_probability_score(y_true, y_proba): >>> ranked_probability_score(y_true, y_pred) 0.506875 """ + y_true = np.array(y_true) + y_proba = np.array(y_proba) + y_oh = np.zeros(y_proba.shape) y_oh[np.arange(len(y_true)), y_true] = 1 @@ -72,6 +75,13 @@ def minimum_sensitivity(y_true: np.ndarray, y_pred: np.ndarray) -> float: >>> minimum_sensitivity(y_true, y_pred) 0.5 """ + y_true = np.array(y_true) + y_pred = np.array(y_pred) + + if len(y_true.shape) > 1: + y_true = np.argmax(y_true, axis=1) + if len(y_pred.shape) > 1: + y_pred = np.argmax(y_pred, axis=1) sensitivities = recall_score(y_true, y_pred, average=None) return np.min(sensitivities) @@ -101,6 +111,8 @@ def accuracy_off1(y_true: np.ndarray, y_pred: np.ndarray, labels=None) -> float: >>> accuracy_off1(y_true, y_pred) 1.0 """ + y_true = np.array(y_true) + y_pred = np.array(y_pred) if len(y_true.shape) > 1: y_true = np.argmax(y_true, axis=1) @@ -141,6 +153,13 @@ def gmsec(y_true: np.ndarray, y_pred: np.ndarray) -> float: >>> gmec(y_true, y_pred) 0.5 """ + y_true = np.array(y_true) + y_pred = np.array(y_pred) + + if len(y_true.shape) > 1: + y_true = np.argmax(y_true, axis=1) + if len(y_pred.shape) > 1: + y_pred = np.argmax(y_pred, axis=1) sensitivities = recall_score(y_true, y_pred, average=None) return np.sqrt(sensitivities[0] * sensitivities[-1]) @@ -161,6 +180,8 @@ def amae(y_true: np.ndarray, y_pred: np.ndarray): amae : float Average mean absolute error. """ + y_true = np.array(y_true) + y_pred = np.array(y_pred) if len(y_true.shape) > 1: y_true = np.argmax(y_true, axis=1) @@ -193,6 +214,8 @@ def mmae(y_true: np.ndarray, y_pred: np.ndarray): mmae : float Maximum mean absolute error. """ + y_true = np.array(y_true) + y_pred = np.array(y_pred) if len(y_true.shape) > 1: y_true = np.argmax(y_true, axis=1)