-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
107 lines (87 loc) · 3.81 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
from scipy.optimize import curve_fit
from scipy import stats
import torch
import torch.nn.functional as F
def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    """Evaluate a four-parameter logistic curve at ``X``.

    yhat = bayta2 + (bayta1 - bayta2) / (1 + exp(-(X - bayta3) / |bayta4|))

    Args:
        X: scalar or NumPy array of input values.
        bayta1: value the curve approaches as X -> +inf.
        bayta2: value the curve approaches as X -> -inf.
        bayta3: midpoint; at X == bayta3 the output is (bayta1 + bayta2) / 2.
        bayta4: transition scale; only its absolute value is used, so the
            sign of this parameter never flips the curve.

    Returns:
        The curve evaluated elementwise at ``X``.
    """
    scaled_offset = (X - bayta3) / np.abs(bayta4)
    denominator = 1 + np.exp(-scaled_offset)
    return bayta2 + (bayta1 - bayta2) / denominator
def fit_function(y_label, y_output):
    """Calibrate raw outputs to the label scale with a fitted logistic curve.

    Fits ``logistic_func`` by non-linear least squares, using ``y_output`` as
    the independent variable and ``y_label`` as the target. It then returns
    the fitted curve evaluated at ``y_output``.

    Args:
        y_label: sequence or array of target values.
        y_output: raw model outputs, one per entry of ``y_label``.

    Returns:
        The logistic-mapped outputs.

    Raises:
        RuntimeError: propagated from ``scipy.optimize.curve_fit`` when the
            least-squares fit fails to converge.
    """
    starting_point = [
        np.max(y_label),    # bayta1: limit of the curve as x -> +inf
        np.min(y_label),    # bayta2: limit of the curve as x -> -inf
        np.mean(y_output),  # bayta3: midpoint of the transition
        0.5,                # bayta4: transition scale
    ]
    best_params, _ = curve_fit(
        logistic_func,
        y_output,
        y_label,
        p0=starting_point,
        maxfev=10 ** 8,
    )
    return logistic_func(y_output, *best_params)
def performance_fit(y_label, y_output):
    """Score predictions against labels after logistic calibration.

    The raw outputs are first mapped onto the label scale with
    ``fit_function``.

    Args:
        y_label: ground-truth values.
        y_output: raw model outputs, one per label.

    Returns:
        Tuple ``(PLCC, SRCC, KRCC, RMSE)``:
        - PLCC: Pearson correlation of the calibrated outputs with the labels.
        - SRCC: Spearman correlation of the raw outputs with the labels.
        - KRCC: Kendall tau of the raw outputs with the labels.
        - RMSE: root-mean-square error of the calibrated outputs.
    """
    calibrated = fit_function(y_label, y_output)

    # Rank correlations are taken on the raw outputs; only PLCC and RMSE
    # use the calibrated values.
    srcc = stats.spearmanr(y_output, y_label)[0]
    krcc = stats.kendalltau(y_output, y_label)[0]
    plcc = stats.pearsonr(calibrated, y_label)[0]

    residual = calibrated - y_label
    rmse = np.sqrt((residual ** 2).mean())
    return plcc, srcc, krcc, rmse
def performance_no_fit(y_label, y_output):
    """Score raw predictions against labels without any calibration.

    Args:
        y_label: ground-truth values (sequence or array).
        y_output: model outputs, one per label (sequence or array).

    Returns:
        Tuple ``(PLCC, SRCC, KRCC, RMSE)``:
        - PLCC: Pearson correlation of the outputs with the labels.
        - SRCC: Spearman correlation of the outputs with the labels.
        - KRCC: Kendall tau of the outputs with the labels.
        - RMSE: root-mean-square error between the outputs and the labels.
    """
    # Convert up front so the RMSE arithmetic also works for plain lists.
    y_label = np.asarray(y_label, dtype=float)
    y_output = np.asarray(y_output, dtype=float)

    PLCC = stats.pearsonr(y_output, y_label)[0]
    SRCC = stats.spearmanr(y_output, y_label)[0]
    # Public `stats.kendalltau`; the `scipy.stats.stats` namespace is deprecated.
    KRCC = stats.kendalltau(y_output, y_label)[0]
    # Error is prediction minus label (previously label minus label, i.e. always 0).
    RMSE = np.sqrt(np.mean((y_output - y_label) ** 2))
    return PLCC, SRCC, KRCC, RMSE
class L1RankLoss(torch.nn.Module):
    """L1 regression loss plus a pairwise ranking hinge loss.

    Keyword args (all optional):
        l1_w (default 1): multiplier for the L1 term.
        rank_w (default 1): multiplier for the rank term.
        hard_thred (default 1): a pair (i, j) contributes to the rank term only
            when its absolute label difference is > 0 and < hard_thred.
        use_margin (default False): if True, each pair's hinge is
            relu(|dy| - sign(dy) * dp), i.e. predictions must be ordered like
            the labels by at least the label gap. Otherwise it is
            relu(-sign(dy) * dp), penalizing only pairs ordered opposite to
            their labels.
        batchsize (default 8): stored on the module but not read by forward().
    """

    def __init__(self, **kwargs):
        super(L1RankLoss, self).__init__()
        self.l1_w = kwargs.get("l1_w", 1)
        self.rank_w = kwargs.get("rank_w", 1)
        self.hard_thred = kwargs.get("hard_thred", 1)
        self.use_margin = kwargs.get("use_margin", False)
        # Kept for backward compatibility with callers that pass/read it.
        self.batchsize = kwargs.get("batchsize", 8)

    def forward(self, preds, gts):
        """Return l1_w * L1(preds, gts) + rank_w * mean hinge over active pairs.

        Both inputs are flattened, so e.g. (B, 1) predictions against (B,)
        labels are paired element by element instead of broadcasting.
        """
        preds = preds.reshape(-1)
        gts = gts.reshape(-1)

        l1_loss = F.l1_loss(preds, gts) * self.l1_w

        # Pairwise differences via broadcasting: entry [i, j] is x[j] - x[i].
        pred_diff = preds.unsqueeze(0) - preds.unsqueeze(1)
        label_diff = gts.unsqueeze(0) - gts.unsqueeze(1)
        label_gap = torch.abs(label_diff)
        signs = torch.sign(label_diff)
        # Only pairs with distinct labels closer than hard_thred are ranked.
        masks_hard = (label_gap < self.hard_thred) & (label_gap > 0)

        if self.use_margin:
            rank_loss = masks_hard * torch.relu(label_gap - signs * pred_diff)
        else:
            rank_loss = masks_hard * torch.relu(-signs * pred_diff)
        # Average over active pairs; epsilon avoids 0/0 when none are active.
        rank_loss = rank_loss.sum() / (masks_hard.sum() + 1e-08)

        return l1_loss + rank_loss * self.rank_w
class L1RankLoss1(torch.nn.Module):
    """
    L1 loss + Rank loss

    Combines a weighted mean-absolute-error term with a weighted pairwise
    ranking hinge computed over every (i, j) pair of samples in the batch.

    Keyword args (all optional):
        l1_w (default 1): multiplier for the L1 term.
        rank_w (default 1): multiplier for the rank term.
        hard_thred (default 1): a pair participates in the rank term only when
            its absolute label difference is > 0 and < hard_thred.
        use_margin (default False): when True, each pair's hinge is
            relu(|dy| - sign(dy) * dp), zero only once predictions are ordered
            like the labels by at least the label gap. Otherwise it is
            relu(-sign(dy) * dp), penalizing only mis-ordered pairs.
    """

    def __init__(self, **kwargs):
        super().__init__()
        for name, default in (
            ("l1_w", 1),
            ("rank_w", 1),
            ("hard_thred", 1),
            ("use_margin", False),
        ):
            setattr(self, name, kwargs.get(name, default))

    def forward(self, preds, gts):
        """Return l1_w * L1 + rank_w * (hinge sum / number of active pairs)."""
        flat_preds = preds.view(-1)
        flat_gts = gts.view(-1)

        weighted_l1 = F.l1_loss(flat_preds, flat_gts) * self.l1_w

        # Pairwise differences via broadcasting: entry [i, j] holds x[j] - x[i].
        pred_gap = flat_preds[None, :] - flat_preds[:, None]
        label_gap = flat_gts[None, :] - flat_gts[:, None]
        abs_label_gap = label_gap.abs()
        direction = label_gap.sign()
        active = (abs_label_gap < self.hard_thred) & (abs_label_gap > 0)

        if not self.use_margin:
            hinge = torch.relu(-direction * pred_gap)
        else:
            hinge = torch.relu(abs_label_gap - direction * pred_gap)
        # Epsilon keeps the ratio finite when no pair is active.
        rank_term = (active * hinge).sum() / (active.sum() + 1e-08)

        return weighted_l1 + rank_term * self.rank_w