forked from PeiqinZhuang/API-Net
-
Notifications
You must be signed in to change notification settings - Fork 2
/
orthogonalprojectionloss.py
33 lines (23 loc) · 1.12 KB
/
orthogonalprojectionloss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
import torch.nn.functional as F
class OrthogonalProjectionLoss(nn.Module):
    """Loss built from pairwise cosine similarities within a batch.

    Features are L2-normalized, then the mean similarity over same-label
    pairs (self-pairs excluded) is pushed towards 1 while the mean
    similarity over different-label pairs is penalized with weight
    ``gamma``::

        loss = (1 - mean_same_label_sim) + gamma * mean_diff_label_sim
    """

    def __init__(self, gamma=0.5):
        """Store ``gamma``, the weight of the different-label term."""
        super().__init__()
        self.gamma = gamma

    def forward(self, features, labels=None):
        """Compute the loss for one batch.

        Args:
            features: 2-D tensor of shape ``(N, D)``; rows are normalized
                to unit L2 norm before similarities are computed.
            labels: tensor with one label per row of ``features``, shaped
                ``(N,)`` or ``(N, 1)``. Required despite the ``None``
                default.

        Returns:
            Scalar tensor holding the loss.

        Raises:
            TypeError: if ``labels`` is ``None``.
        """
        if labels is None:
            raise TypeError(
                "OrthogonalProjectionLoss.forward() requires `labels`, got None"
            )

        # Use the exact device of the features (e.g. 'cuda:1'); a bare
        # 'cuda' resolves to the current default GPU and may not match.
        device = features.device

        # Unit-length rows make every dot product a cosine similarity.
        features = F.normalize(features, p=2, dim=1)

        # Column vector of labels; reshape accepts both (N,) and (N, 1).
        labels = labels.reshape(-1, 1).to(device)

        # same_label[i, j] is True when samples i and j share a label.
        same_label = torch.eq(labels, labels.t())
        eye = torch.eye(same_label.shape[0], device=device).bool()

        # Positive pairs: same label, excluding each sample with itself.
        mask_pos = (same_label & ~eye).float()
        # Negative pairs: different labels.
        mask_neg = (~same_label).float()

        dot_prod = torch.matmul(features, features.t())

        # The 1e-6 keeps the means finite when a batch has no positive
        # (or no negative) pairs.
        pos_pairs_mean = (mask_pos * dot_prod).sum() / (mask_pos.sum() + 1e-6)
        # NOTE: negative similarities are averaged signed (no abs), as in
        # the original code's "removed abs" TODO.
        neg_pairs_mean = (mask_neg * dot_prod).sum() / (mask_neg.sum() + 1e-6)

        loss = (1.0 - pos_pairs_mean) + self.gamma * neg_pairs_mean
        return loss