# dice_loss.py (forked from open-mmlab/mmsegmentation)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
import torch.nn as nn
from mmseg.registry import MODELS
from .utils import weight_reduce_loss


def _expand_onehot_labels_dice(pred: torch.Tensor,
                               target: torch.Tensor) -> torch.Tensor:
    """Expand onehot labels to match the size of prediction.

    Args:
        pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W).
        target (torch.Tensor): The learning label of the prediction,
            has a shape (N, H, W).

    Returns:
        torch.Tensor: The target after one-hot encoding,
            has a shape (N, num_class, H, W).
    """
    num_classes = pred.shape[1]
    # Clamp so that ignore labels (e.g. 255) map to the extra class index
    # `num_classes`, whose channel is dropped after one-hot encoding below.
    one_hot_target = torch.clamp(target, min=0, max=num_classes)
    one_hot_target = torch.nn.functional.one_hot(one_hot_target,
                                                 num_classes + 1)
    one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2)
    return one_hot_target
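
# Hedged usage sketch (not part of the original file): with num_classes = 3,
# an ignore label such as 255 is clamped to 3, one-hot encoded into a fourth
# channel, and that channel is then dropped, so ignored pixels become all-zero
# vectors across the class channels:
#
#     pred = torch.randn(1, 3, 2, 2)
#     target = torch.tensor([[[0, 1], [2, 255]]])
#     one_hot = _expand_onehot_labels_dice(pred, target)
#     assert one_hot.shape == (1, 3, 2, 2)
#     assert one_hot[0, :, 1, 1].sum() == 0  # the 255-labeled pixel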


def dice_loss(pred: torch.Tensor,
              target: torch.Tensor,
              weight: Union[torch.Tensor, None],
              eps: float = 1e-3,
              reduction: Union[str, None] = 'mean',
              naive_dice: Union[bool, None] = False,
              avg_factor: Union[int, None] = None,
              ignore_index: Union[int, None] = 255) -> torch.Tensor:
    """Calculate dice loss. Two forms of dice loss are supported:

    - the one proposed in `V-Net: Fully Convolutional Neural
      Networks for Volumetric Medical Image Segmentation
      <https://arxiv.org/abs/1606.04797>`_.
    - the naive dice loss, in which the terms in the denominator are
      raised to the first power instead of the second power.

    Args:
        pred (torch.Tensor): The prediction, has a shape (n, *).
        target (torch.Tensor): The learning label of the prediction,
            shape (n, *), same shape as pred.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, has a shape (n,). Defaults to None.
        eps (float): Avoid dividing by zero. Default: 1e-3.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
            Options are "none", "mean" and "sum".
        naive_dice (bool, optional): If False, use the dice
            loss defined in the V-Net paper; otherwise, use the
            naive dice loss in which the terms in the denominator
            are raised to the first power instead of the second
            power. Defaults to False.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        ignore_index (int, optional): The label index to be ignored.
            Defaults to 255.
    """
    if ignore_index is not None:
        num_classes = pred.shape[1]
        # Drop the ignored class channel from both prediction and target.
        pred = pred[:, torch.arange(num_classes) != ignore_index, :, :]
        target = target[:, torch.arange(num_classes) != ignore_index, :, :]
        assert pred.shape[1] != 0  # the ignored index must not be the only class
    input = pred.flatten(1)
    target = target.flatten(1).float()

    a = torch.sum(input * target, 1)
    if naive_dice:
        b = torch.sum(input, 1)
        c = torch.sum(target, 1)
        d = (2 * a + eps) / (b + c + eps)
    else:
        b = torch.sum(input * input, 1) + eps
        c = torch.sum(target * target, 1) + eps
        d = (2 * a) / (b + c)

    loss = 1 - d
    if weight is not None:
        assert weight.ndim == loss.ndim
        assert len(weight) == len(pred)
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
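
# Hedged sketch (not part of the original file): the two supported forms,
# written out for a single flattened prediction p and target t:
#
#     V-Net form (naive_dice=False): 1 - 2*sum(p*t) / (sum(p*p) + sum(t*t) + 2*eps)
#     naive form (naive_dice=True):  1 - (2*sum(p*t) + eps) / (sum(p) + sum(t) + eps)
#
# A tiny check of the V-Net form against dice_loss, with ignore_index disabled
# so no channel is dropped:
#
#     p = torch.tensor([[0.8, 0.2, 0.6, 0.4]])
#     t = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
#     expected = 1 - 2 * (p * t).sum() / ((p * p).sum() + (t * t).sum() + 2e-3)
#     out = dice_loss(p, t, weight=None, reduction='none', ignore_index=None)
#     assert torch.allclose(out, expected.unsqueeze(0))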


@MODELS.register_module()
class DiceLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 activate=True,
                 reduction='mean',
                 naive_dice=False,
                 loss_weight=1.0,
                 ignore_index=255,
                 eps=1e-3,
                 loss_name='loss_dice'):
        """Compute dice loss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction is
                activated with sigmoid (True) or softmax (False).
                Defaults to True.
            activate (bool): Whether to apply the activation
                (sigmoid or softmax) to the predictions inside this loss.
                If False, `pred` is assumed to be activated already.
                Defaults to True.
            reduction (str, optional): The method used
                to reduce the loss. Options are "none",
                "mean" and "sum". Defaults to 'mean'.
            naive_dice (bool, optional): If False, use the dice
                loss defined in the V-Net paper; otherwise, use the
                naive dice loss in which the terms in the denominator
                are raised to the first power instead of the second
                power. Defaults to False.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            ignore_index (int, optional): The label index to be ignored.
                Default: 255.
            eps (float): Avoid dividing by zero. Defaults to 1e-3.
            loss_name (str, optional): Name of the loss item. If you want this
                loss item to be included into the backward graph, `loss_` must
                be the prefix of the name. Defaults to 'loss_dice'.
        """
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.reduction = reduction
        self.naive_dice = naive_dice
        self.loss_weight = loss_weight
        self.eps = eps
        self.activate = activate
        self.ignore_index = ignore_index
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=255,  # unused here; `self.ignore_index` is applied below
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction, has a shape (n, *).
            target (torch.Tensor): The label of the prediction,
                shape (n, *), same shape as pred.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction, has a shape (n,). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss.
        """
        one_hot_target = target
        if pred.shape != target.shape:
            one_hot_target = _expand_onehot_labels_dice(pred, target)
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.activate:
            if self.use_sigmoid:
                pred = pred.sigmoid()
            elif pred.shape[1] != 1:
                # softmax does not work when there is only 1 class
                pred = pred.softmax(dim=1)
        loss = self.loss_weight * dice_loss(
            pred,
            one_hot_target,
            weight,
            eps=self.eps,
            reduction=reduction,
            naive_dice=self.naive_dice,
            avg_factor=avg_factor,
            ignore_index=self.ignore_index)
        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
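

# Hedged usage sketch (not part of the original file): constructing the loss
# directly and feeding it raw logits with hard integer labels. Shapes and
# values are illustrative only:
#
#     loss_fn = DiceLoss(use_sigmoid=False, naive_dice=True, ignore_index=None)
#     logits = torch.randn(2, 4, 8, 8)           # (N, num_class, H, W)
#     labels = torch.randint(0, 4, (2, 8, 8))    # (N, H, W), hard labels
#     loss = loss_fn(logits, labels)             # scalar under 'mean' reduction
#
# In an mmsegmentation config the same module would typically be selected
# through the registry instead, e.g. loss_decode=dict(type='DiceLoss',
# loss_weight=1.0).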