-
Notifications
You must be signed in to change notification settings - Fork 0
/
pytorch_loss.py
90 lines (55 loc) · 2.71 KB
/
pytorch_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import torch.nn as nn
import torch.nn.functional as F
class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits.

    Channels are expected on dim 1 of both ``y_pred`` and ``y``.

    * One channel: the logits go through ``sigmoid`` and a single Dice score
      is computed over every element.
    * Several channels: the logits go through ``softmax`` over dim 1. One Dice
      score is computed per channel, reducing over every other dim. The
      per-channel losses are then averaged.
    """

    def __init__(self):
        super().__init__()
        # Per-channel losses (``1 - dice``) from the most recent multi-channel
        # forward call. Replaced on every call. Appending to it across calls
        # skewed the returned average and kept old autograd graphs alive.
        self.loss_list = []

    def forward(self, y_pred, y, smooth=1e-15):
        """Return the Dice loss.

        Args:
            y_pred: logits with channels on dim 1.
            y: target with the same shape as ``y_pred``. Presumably a 0/1
                (one-hot) mask; this is not checked.
            smooth: added to both numerator and denominator so the ratio stays
                defined when prediction and target are both empty.

        Returns:
            A 0-dim tensor. In the single-channel case it is ``1 - dice``;
            otherwise it is the mean of the per-channel ``1 - dice`` values.
        """
        if y.shape[1] == 1:
            # ``reshape`` (not ``view``) so non-contiguous inputs are accepted.
            probs = torch.sigmoid(y_pred).reshape(-1)
            target = y.reshape(-1)
            intersection = (probs * target).sum()
            dice = (2 * intersection + smooth) / (probs.sum() + target.sum() + smooth)
            return 1 - dice

        # Softmax once over the channel dim instead of once per channel.
        probs = F.softmax(y_pred, dim=1)
        # Reduce over every dim except the channel dim. Summing directly also
        # works for batch size > 1, where ``x[:, i].view(-1)`` raises because
        # the channel slice is non-contiguous.
        dims = (0,) + tuple(range(2, y.dim()))
        intersection = (probs * y).sum(dim=dims)
        dice = (2 * intersection + smooth) / (
            probs.sum(dim=dims) + y.sum(dim=dims) + smooth
        )
        per_class = 1 - dice
        self.loss_list = list(per_class.unbind(0))
        return per_class.mean()
class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss computed from raw logits.

    Channels are expected on dim 1 of both ``y_pred`` and ``y``.

    * One channel: the logits go through ``sigmoid`` and a single IoU is
      computed over every element.
    * Several channels: the logits go through ``softmax`` over dim 1. One IoU
      is computed per channel, reducing over every other dim. The per-channel
      losses are then averaged.
    """

    def __init__(self):
        super().__init__()
        # Per-channel losses (``1 - iou``) from the most recent multi-channel
        # forward call. Replaced on every call. Appending to it across calls
        # skewed the returned average and kept old autograd graphs alive.
        self.loss_list = []

    def forward(self, y_pred, y, smooth=1e-15):
        """Return the IoU loss.

        Args:
            y_pred: logits with channels on dim 1.
            y: target with the same shape as ``y_pred``. Presumably a 0/1
                (one-hot) mask; this is not checked.
            smooth: added to both numerator and denominator, i.e.
                ``(I + smooth) / (U + smooth)``. This keeps the ratio defined
                when prediction and target are both empty. Previously it was
                folded into the intersection and cancelled out of the
                denominator, which gave ``inf`` there.

        Returns:
            A 0-dim tensor. In the single-channel case it is ``1 - iou``;
            otherwise it is the mean of the per-channel ``1 - iou`` values.
        """
        if y.shape[1] == 1:
            # ``reshape`` (not ``view``) so non-contiguous inputs are accepted.
            probs = torch.sigmoid(y_pred).reshape(-1)
            target = y.reshape(-1)
            intersection = (probs * target).sum()
            union = probs.sum() + target.sum() - intersection
            iou = (intersection + smooth) / (union + smooth)
            return 1 - iou

        # Softmax once over the channel dim instead of once per channel.
        probs = torch.softmax(y_pred, dim=1)
        # Reduce over every dim except the channel dim. Summing directly also
        # works for batch size > 1, where ``x[:, i].view(-1)`` raises because
        # the channel slice is non-contiguous.
        dims = (0,) + tuple(range(2, y.dim()))
        intersection = (probs * y).sum(dim=dims)
        union = probs.sum(dim=dims) + y.sum(dim=dims) - intersection
        iou = (intersection + smooth) / (union + smooth)
        per_class = 1 - iou
        self.loss_list = list(per_class.unbind(0))
        return per_class.mean()
if __name__ == "__main__":
    # Smoke test for both losses. It is guarded so that importing the module
    # has no side effects. The GPU is used only when one is available, so the
    # demo also runs on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_classes = 10

    # Targets should be one-hot masks (0/1 per channel), not Gaussian noise.
    # Build them from random class indices, then move the channel axis to
    # dim 1 to match the logits.
    labels = torch.randint(0, num_classes, (1, 512, 512), device=device)
    y = F.one_hot(labels, num_classes).permute(0, 3, 1, 2).float()
    y_pred = torch.randn(1, num_classes, 512, 512, device=device)

    dice_criterion = DiceLoss()
    loss_dice = dice_criterion(y_pred, y)
    print("loss dice: ", loss_dice)

    iou_criterion = IoULoss()
    loss_iou = iou_criterion(y_pred, y)
    print("Loss iou: ", loss_iou)