util.py
import torch
import torch.nn.functional as F
from torch import nn, Tensor


class CircleLoss(nn.Module):
    """Circle loss on pair-wise similarities (Sun et al., CVPR 2020)."""

    def __init__(self, m: float, gamma: float) -> None:
        super(CircleLoss, self).__init__()
        self.m = m                      # relaxation margin
        self.gamma = gamma              # scale factor
        self.soft_plus = nn.Softplus()

    def forward(self, sp: Tensor, sn: Tensor) -> Tensor:
        # Self-paced weights; gradients do not flow through them.
        ap = torch.clamp_min(-sp.detach() + 1 + self.m, min=0.)
        an = torch.clamp_min(sn.detach() + self.m, min=0.)

        # Decision margins for positive and negative similarities.
        delta_p = 1 - self.m
        delta_n = self.m

        logit_p = -ap * (sp - delta_p) * self.gamma
        logit_n = an * (sn - delta_n) * self.gamma

        loss = self.soft_plus(torch.logsumexp(logit_n, dim=0) + torch.logsumexp(logit_p, dim=0))

        return loss
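

# Usage sketch, not part of the original util.py: CircleLoss expects flat
# tensors of within-class similarities sp and between-class similarities sn,
# typically cosine similarities of L2-normalised embeddings. The batch size,
# feature dimension, labels, and hyperparameters below are illustrative
# assumptions, not values taken from this repository.
def _circle_loss_example() -> Tensor:
    feats = F.normalize(torch.rand(8, 64), dim=1)    # hypothetical embeddings
    labels = torch.tensor([0, 0, 1, 1, 2, 2, 0, 1])
    sim = feats @ feats.t()                          # (8, 8) cosine similarities
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    not_self = ~torch.eye(8, dtype=torch.bool)
    sp = sim[same & not_self]                        # within-class similarities
    sn = sim[~same]                                  # between-class similarities
    criterion = CircleLoss(m=0.25, gamma=256)        # commonly used settings
    return criterion(sp, sn)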


def cal_loss(pred, gold, smoothing=True):
    ''' Calculate cross-entropy loss, applying label smoothing if enabled. '''
    gold = gold.contiguous().view(-1)

    if smoothing:
        eps = 0.2
        n_class = pred.size(1)

        # Smoothed one-hot targets: 1 - eps on the true class,
        # eps / (n_class - 1) spread over the remaining classes.
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction='mean')

    return loss
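

# Usage sketch, not part of the original util.py: cal_loss takes raw,
# unnormalised logits of shape (batch, n_class) and integer class labels.
# The batch size and class count below are arbitrary assumptions.
def _cal_loss_example() -> Tensor:
    logits = torch.randn(4, 10)            # unnormalised scores, (batch, n_class)
    labels = torch.randint(0, 10, (4,))    # integer class indices
    return cal_loss(logits, labels, smoothing=True)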


class IOStream():
    def __init__(self, path):
        self.f = open(path, 'a')

    def cprint(self, text):
        print(text)
        self.f.write(text + '\n')
        self.f.flush()

    def close(self):
        self.f.close()
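

# Usage sketch, not part of the original util.py: IOStream echoes each
# cprint() call to stdout and appends it to the log file. The file name
# below is a hypothetical example.
def _iostream_example() -> None:
    io = IOStream('run.log')               # assumed log path
    io.cprint('epoch 1, train acc: 0.91')
    io.close()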


def huber_loss(error, delta):
    """Huber loss: quadratic for |error| <= delta, linear beyond delta."""
    abs_error = torch.abs(error)
    quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
    losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic)
    return torch.mean(losses)
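

# Usage sketch, not part of the original util.py: huber_loss operates
# element-wise on a tensor of residuals (prediction - target) and returns the
# mean. The tensor shapes and delta=1.0 below are illustrative assumptions.
def _huber_loss_example() -> Tensor:
    pred = torch.randn(16, 3)
    target = torch.randn(16, 3)
    return huber_loss(pred - target, delta=1.0)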