-
Notifications
You must be signed in to change notification settings - Fork 74
/
metric_strategy.py
82 lines (63 loc) · 2.49 KB
/
metric_strategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Swish(torch.autograd.Function):
    """Swish activation x * sigmoid(x) with a hand-written backward pass.

    Saving only the raw input (instead of input and output) keeps memory
    low; sigmoid is recomputed in backward.
    """

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        # NOTE: ctx.saved_tensors replaces ctx.saved_variables, which was
        # deprecated and then removed from PyTorch.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
class Swish_module(nn.Module):
    """nn.Module wrapper so the custom Swish function can sit in a Sequential."""

    def forward(self, x):
        activated = Swish.apply(x)
        return activated
class ArcMarginProduct_subcenter(nn.Module):
    """Sub-center ArcFace head: k weight vectors ("sub-centers") per class.

    forward() returns a (batch, out_features) matrix of cosine similarities
    between L2-normalised features and L2-normalised class weights, keeping
    only the best-matching sub-center for each class.
    """

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        # One weight row per (class, sub-center) pair.
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        # Classic fan-in uniform initialisation.
        bound = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    def forward(self, features):
        normalized = F.normalize(features)
        all_cosines = F.linear(normalized, F.normalize(self.weight))
        grouped = all_cosines.view(-1, self.out_features, self.k)
        # Max over the sub-center axis: each class is represented by
        # its closest sub-center.
        best, _ = grouped.max(dim=2)
        return best
class DenseCrossEntropy(nn.Module):
    """Cross-entropy against dense (one-hot or soft) target distributions."""

    def forward(self, x, target):
        # Cast both to float32 regardless of incoming precision.
        logits = x.float()
        soft_target = target.float()
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        # Per-sample CE: negative dot product of target with log-probs.
        per_sample = -(soft_target * log_probs).sum(dim=-1)
        return per_sample.mean()
class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    """ArcFace loss with a per-class additive angular margin.

    Parameters
    ----------
    margins : numpy array of per-class margins (radians), indexed by class id.
    s : float, logit scale applied before the cross-entropy.
    """

    def __init__(self, margins, s=30.0):
        super().__init__()
        self.crit = DenseCrossEntropy()
        self.s = s
        self.margins = margins

    def forward(self, logits, labels, out_dim):
        # Look up each sample's margin by its integer label, then move the
        # derived per-sample constants to wherever the logits live.
        # (Was hard-coded .cuda(), which crashed on CPU-only runs; a dead
        # "ms = []" assignment was also removed.)
        device = logits.device
        ms = self.margins[labels.cpu().numpy()]
        cos_m = torch.from_numpy(np.cos(ms)).float().to(device)
        sin_m = torch.from_numpy(np.sin(ms)).float().to(device)
        th = torch.from_numpy(np.cos(math.pi - ms)).float().to(device)
        mm = torch.from_numpy(np.sin(math.pi - ms) * ms).float().to(device)
        labels = F.one_hot(labels, out_dim).float()
        logits = logits.float()
        # Inputs are assumed to already be cosine similarities in [-1, 1]
        # (e.g. from ArcMarginProduct_subcenter) — TODO confirm with caller.
        cosine = logits
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        # cos(theta + m) = cos(theta)*cos(m) - sin(theta)*sin(m)
        phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1)
        # Outside the monotonic range of cos(theta + m), fall back to the
        # linear penalty cosine - m*sin(m) (standard ArcFace easy-margin trick).
        phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1))
        # Apply the margin only to the true-class logit, then scale.
        output = (labels * phi) + ((1.0 - labels) * cosine)
        output *= self.s
        loss = self.crit(output, labels)
        return loss