-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
62 lines (51 loc) · 2.22 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
class HybridLoss(torch.nn.Module):
    """L1 fidelity loss plus optional spatial / spectral total-variation terms.

    Args:
        lamd: Stored as ``self.lamd`` but not read anywhere in this class; kept
            so existing callers and checkpoints keep working.
            NOTE(review): presumably intended to weight the TV terms -- confirm
            before relying on it.
        spatial_tv: If True, ``forward`` adds ``TVLoss`` computed on the prediction.
        spectral_tv: If True, ``forward`` adds ``TVLossSpectral`` computed on the
            prediction.
        spatial_weight: Keyword-only weight passed to ``TVLoss``. The default
            (1e-3) is the value that used to be hard-coded.
        spectral_weight: Keyword-only weight passed to ``TVLossSpectral``. The
            default (1e-3) is the value that used to be hard-coded.
    """

    def __init__(self, lamd=1e-1, spatial_tv=False, spectral_tv=False, *,
                 spatial_weight=1e-3, spectral_weight=1e-3):
        super(HybridLoss, self).__init__()
        self.lamd = lamd
        self.use_spatial_TV = spatial_tv
        self.use_spectral_TV = spectral_tv
        self.fidelity = torch.nn.L1Loss()
        # Both regularizers are always constructed (as before) so the module's
        # attribute layout does not depend on the flags.
        self.spatial = TVLoss(weight=spatial_weight)
        self.spectral = TVLossSpectral(weight=spectral_weight)

    def forward(self, y, gt):
        """Return ``L1(y, gt)`` plus every enabled TV penalty evaluated on ``y``.

        Args:
            y: Model prediction.
            gt: Ground-truth target, compared to ``y`` with ``torch.nn.L1Loss``.

        Returns:
            Scalar loss tensor.
        """
        total_loss = self.fidelity(y, gt)
        # TV terms regularize the prediction only; the target is not involved.
        if self.use_spatial_TV:
            total_loss = total_loss + self.spatial(y)
        if self.use_spectral_TV:
            total_loss = total_loss + self.spectral(y)
        return total_loss
# from https://github.com/jxgu1016/Total_Variation_Loss.pytorch with slight modifications
class TVLoss(torch.nn.Module):
    """Squared-difference (L2) total-variation penalty over dims 2 and 3.

    ``forward`` indexes a 4-D tensor and takes neighbour differences along
    dims 2 and 3 (presumably height and width of an (N, C, H, W) batch --
    confirm against callers).

    Args:
        weight: Scalar multiplier applied to the penalty.
    """

    def __init__(self, weight=1.0):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = weight

    def forward(self, x):
        """Return the weighted spatial TV of ``x``.

        For each direction, the summed squared neighbour differences are
        divided by the per-sample difference count (dims 1-3) and then by the
        batch size, which gives the mean squared difference. A direction with
        fewer than two entries has no differences and contributes 0. An empty
        batch returns 0. In both cases the result stays attached to the
        autograd graph.

        NOTE(review): ``TVLossSpectral`` in this file multiplies its term by 2
        and this class does not. Confirm that the asymmetry is intended.
        (An L1 variant would use ``torch.abs(diff)`` instead of squaring.)
        """
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        # Clamp counts to >= 1. A sum over zero elements is exactly 0, so this
        # only turns the degenerate 0/0 (NaN) into 0 and changes no other value.
        count_h = max(self._tensor_size(x[:, :, 1:, :]), 1)
        count_w = max(self._tensor_size(x[:, :, :, 1:]), 1)
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.TVLoss_weight * (h_tv / count_h + w_tv / count_w) / max(batch_size, 1)

    def _tensor_size(self, t):
        """Return the number of elements per sample (product of dims 1-3)."""
        return t.size()[1] * t.size()[2] * t.size()[3]
class TVLossSpectral(torch.nn.Module):
    """Squared-difference (L2) total-variation penalty along dim 1.

    ``forward`` indexes a 4-D tensor and takes neighbour differences along
    dim 1 (presumably the channel / spectral-band axis of an (N, C, H, W)
    batch -- confirm against callers).

    Args:
        weight: Scalar multiplier applied to the penalty.
    """

    def __init__(self, weight=1.0):
        super(TVLossSpectral, self).__init__()
        self.TVLoss_weight = weight

    def forward(self, x):
        """Return ``weight * 2 * mean squared dim-1 difference`` of ``x``.

        The summed squared differences are divided by the per-sample
        difference count (dims 1-3) and then by the batch size. A single-band
        input (no neighbours along dim 1) or an empty batch yields 0 rather
        than NaN, and the result stays attached to the autograd graph.
        """
        batch_size = x.size()[0]
        c_x = x.size()[1]
        # Clamp the denominators to >= 1. The sum below is exactly 0 whenever
        # either would be 0, so only the degenerate 0/0 (NaN) case changes.
        count_c = max(self._tensor_size(x[:, 1:, :, :]), 1)
        c_tv = torch.pow((x[:, 1:, :, :] - x[:, :c_x - 1, :, :]), 2).sum()
        return self.TVLoss_weight * 2 * (c_tv / count_c) / max(batch_size, 1)

    def _tensor_size(self, t):
        """Return the number of elements per sample (product of dims 1-3)."""
        return t.size()[1] * t.size()[2] * t.size()[3]