-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
48 lines (38 loc) · 1.51 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch.nn.functional as F
import torch.nn as nn
class MarginLoss(nn.Module):
    '''Margin loss for digit existence (CapsNet, Eq. 4).

    L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2
    '''
    def __init__(self, size_average=False, loss_lambda=0.5):
        '''
        Args:
            size_average: average (True) or sum (False) the losses over
                observations for each minibatch.
            loss_lambda: down-weighting factor for the loss of missing digits.
        '''
        super(MarginLoss, self).__init__()
        self.size_average = size_average
        # Margins from the CapsNet paper.
        self.m_plus = 0.9
        self.m_minus = 0.1
        self.loss_lambda = loss_lambda
    def forward(self, inputs, labels):
        # Present-class term: penalize capsule lengths below m+.
        present = labels * F.relu(self.m_plus - inputs).pow(2)
        # Absent-class term: penalize capsule lengths above m-, down-weighted.
        absent = (1 - labels) * F.relu(inputs - self.m_minus).pow(2)
        per_sample = (present + self.loss_lambda * absent).sum(dim=1)
        return per_sample.mean() if self.size_average else per_sample.sum()
class CapsuleLoss(nn.Module):
    def __init__(self, loss_lambda=0.5, recon_loss_scale=5e-4, size_average=False):
        '''
        Combined margin loss and reconstruction loss. Margin loss see above.
        Sum squared error (SSE) was used as a reconstruction loss.
        Args:
            loss_lambda: down-weighting factor passed through to MarginLoss.
            recon_loss_scale: param for scaling down the reconstruction loss
            size_average: if True, reconstruction loss becomes MSE instead of SSE
        '''
        super(CapsuleLoss, self).__init__()
        self.size_average = size_average
        # Bug fix: recon_loss_scale was previously accepted but never stored,
        # so the reconstruction term could never be applied.
        self.recon_loss_scale = recon_loss_scale
        self.margin_loss = MarginLoss(size_average=size_average, loss_lambda=loss_lambda)
    def forward(self, inputs, labels, images=None, reconstructions=None):
        '''
        Args:
            inputs: capsule output lengths, shape (batch, num_classes).
            labels: one-hot targets, same shape as inputs.
            images: original input images (optional; flattened or image-shaped,
                must match reconstructions).
            reconstructions: decoder output to compare against images (optional).
        Returns:
            Margin loss, plus the scaled reconstruction loss when both
            images and reconstructions are provided (previously the
            reconstruction term promised by the docstring was never computed).
        '''
        margin_loss = self.margin_loss(inputs, labels)
        if images is not None and reconstructions is not None:
            # SSE by default; MSE when size_average=True, as documented.
            reduction = 'mean' if self.size_average else 'sum'
            recon_loss = F.mse_loss(reconstructions, images, reduction=reduction)
            return margin_loss + self.recon_loss_scale * recon_loss
        return margin_loss