-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
36 lines (28 loc) · 1.34 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from torchmetrics.classification import MultilabelPrecision, MultilabelF1Score, MultilabelAccuracy, MultilabelAUROC, MultilabelRecall
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryRecall, BinaryPrecision, BinaryF1Score
import numpy as np
import torch
def classification_binary_metrics(predictions, labels, *, threshold=0.5):
    """Compute binary-classification metrics for a single set of predictions.

    New torchmetrics metric objects are built on every call, so the returned
    values describe only the tensors passed in; nothing accumulates across
    calls.

    Args:
        predictions: tensor of predicted scores for the positive class.
            Presumably probabilities or logits; how they are interpreted is
            left to torchmetrics (TODO confirm against callers).
        labels: tensor of ground-truth binary targets matching
            ``predictions``.
        threshold: cut-off used to binarise ``predictions`` for accuracy,
            F1, precision and recall. Defaults to 0.5, the torchmetrics
            default the original code relied on implicitly. AUROC is
            threshold-free and does not use it.

    Returns:
        Tuple ``(accuracy, f1, precision, recall, auroc)`` of tensors, in that
        order.
    """
    metrics = (
        BinaryAccuracy(threshold=threshold),
        BinaryF1Score(threshold=threshold),
        BinaryPrecision(threshold=threshold),
        BinaryRecall(threshold=threshold),
        BinaryAUROC(),
    )
    # Metric state is created on the CPU; move each metric next to the inputs
    # so GPU tensors do not trigger a device-mismatch error inside torchmetrics.
    device = predictions.device
    return tuple(metric.to(device)(predictions, labels) for metric in metrics)
def classification_multilabel_metrics(predictions, labels, num_labels=12, *,
                                      threshold=0.5, average='micro',
                                      accuracy_average='weighted'):
    """Compute multilabel-classification metrics for a single set of predictions.

    New torchmetrics metric objects are built on every call, so the returned
    values describe only the tensors passed in; nothing accumulates across
    calls.

    Args:
        predictions: tensor of per-label predicted scores. Presumably shaped
            (N, num_labels) with probabilities or logits (TODO confirm against
            callers).
        labels: tensor of ground-truth 0/1 targets matching ``predictions``.
        num_labels: number of labels. Defaults to 12, the value previously
            hard-coded here.
        threshold: cut-off used to binarise ``predictions`` for accuracy,
            F1, precision and recall. Defaults to 0.5, the torchmetrics
            default the original code relied on implicitly.
        average: averaging mode for F1, precision, recall and AUROC
            (default ``'micro'``, as before).
        accuracy_average: averaging mode for accuracy (default
            ``'weighted'``, as before).

    Returns:
        Tuple ``(accuracy, f1, precision, recall, auroc)`` of tensors, in that
        order.
    """
    # NOTE(review): accuracy uses 'weighted' averaging while every other
    # metric uses 'micro'. This preserves the original behaviour; confirm it
    # is intentional.
    metrics = (
        MultilabelAccuracy(num_labels=num_labels, threshold=threshold,
                           average=accuracy_average),
        MultilabelF1Score(num_labels=num_labels, threshold=threshold,
                          average=average),
        MultilabelPrecision(num_labels=num_labels, threshold=threshold,
                            average=average),
        MultilabelRecall(num_labels=num_labels, threshold=threshold,
                         average=average),
        MultilabelAUROC(num_labels=num_labels, average=average),
    )
    # Metric state is created on the CPU; move each metric next to the inputs
    # so GPU tensors do not trigger a device-mismatch error inside torchmetrics.
    device = predictions.device
    return tuple(metric.to(device)(predictions, labels) for metric in metrics)