Implement different channel reductions for dice (#194)
constantinpape committed Jan 2, 2024
1 parent fad7fb6 commit d6cf34d
Showing 2 changed files with 45 additions and 9 deletions.
16 changes: 16 additions & 0 deletions test/loss/test_dice.py
@@ -48,6 +48,22 @@ def test_dice_invalid(self):
         with self.assertRaises(ValueError):
             loss(x, y)

+    def test_dice_reduction(self):
+        from torch_em.loss import DiceLoss
+
+        shape = (1, 3, 32, 32)
+        x = torch.rand(*shape)
+        y = torch.rand(*shape)
+
+        for reduction in (None, "mean", "min", "max", "sum"):
+            loss = DiceLoss(reduce_channel=reduction)
+            lval = loss(x, y)
+            if reduction is None:
+                self.assertEqual(tuple(lval.shape), (3,))
+            else:
+                self.assertEqual(tuple(lval.shape), tuple())
+                self.assertEqual(lval.numel(), 1)


 if __name__ == '__main__':
     unittest.main()
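The new test pins down the shape contract of reduce_channel: None keeps one loss value per channel, while any named reduction collapses the result to a scalar. A minimal sketch of the same behavior outside the test suite (variable names and the choice of "mean" are illustrative; shapes follow the (batch, channels, height, width) convention used above):

    import torch
    from torch_em.loss import DiceLoss

    x = torch.rand(1, 3, 32, 32)  # prediction with 3 channels
    y = torch.rand(1, 3, 32, 32)  # target of the same shape

    per_channel = DiceLoss(reduce_channel=None)(x, y)
    print(tuple(per_channel.shape))  # (3,) -- one loss value per channel

    reduced = DiceLoss(reduce_channel="mean")(x, y)
    print(tuple(reduced.shape))      # () -- a single scalar loss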
38 changes: 29 additions & 9 deletions torch_em/loss/dice.py
@@ -24,9 +24,10 @@ def flatten_samples(input_):
     return flattened


-def dice_score(input_, target, invert=False, channelwise=True, eps=1e-7):
+def dice_score(input_, target, invert=False, channelwise=True, reduce_channel="sum", eps=1e-7):
     if input_.shape != target.shape:
         raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
+
     if channelwise:
         # Flatten input and target to have the shape (C, N),
         # where N is the number of samples
@@ -39,48 +40,67 @@ def dice_score(input_, target, invert=False, channelwise=True, eps=1e-7):
         channelwise_score = 2 * (numerator / denominator.clamp(min=eps))
         if invert:
             channelwise_score = 1. - channelwise_score
-        # Sum over the channels to compute the total score
-        score = channelwise_score.sum()
+
+        # Reduce the dice score over the channels to compute the overall dice score.
+        # (default is to use the sum)
+        if reduce_channel is None:
+            score = channelwise_score
+        elif reduce_channel == "sum":
+            score = channelwise_score.sum()
+        elif reduce_channel == "mean":
+            score = channelwise_score.mean()
+        elif reduce_channel == "max":
+            score = channelwise_score.max()
+        elif reduce_channel == "min":
+            score = channelwise_score.min()
+        else:
+            raise ValueError(f"Unsupported channel reduction {reduce_channel}")

     else:
         numerator = (input_ * target).sum()
         denominator = (input_ * input_).sum() + (target * target).sum()
         score = 2. * (numerator / denominator.clamp(min=eps))
         if invert:
             score = 1. - score

     return score


 class DiceLoss(nn.Module):
-    def __init__(self, channelwise=True, eps=1e-7):
+    def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
+        if reduce_channel not in ("sum", "mean", "max", "min", None):
+            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
         super().__init__()
         self.channelwise = channelwise
         self.eps = eps
+        self.reduce_channel = reduce_channel

         # all torch_em classes should store init kwargs to easily recreate the init call
-        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps}
+        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}

     def forward(self, input_, target):
         return dice_score(input_, target,
                           invert=True, channelwise=self.channelwise,
-                          eps=self.eps)
+                          eps=self.eps, reduce_channel=self.reduce_channel)


 class DiceLossWithLogits(nn.Module):
-    def __init__(self, channelwise=True, eps=1e-7):
+    def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
         super().__init__()
         self.channelwise = channelwise
         self.eps = eps
+        self.reduce_channel = reduce_channel

         # all torch_em classes should store init kwargs to easily recreate the init call
-        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps}
+        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}

     def forward(self, input_, target):
         return dice_score(
             nn.functional.sigmoid(input_),
             target,
             invert=True,
             channelwise=self.channelwise,
-            eps=self.eps
+            eps=self.eps,
+            reduce_channel=self.reduce_channel,
         )
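For reference, dice_score computes the soft dice coefficient 2·Σ(x·y) / (Σx² + Σy²), per channel when channelwise=True, and reduce_channel controls how the per-channel values are combined. Since both loss classes call it with invert=True, each channel contributes 1 minus its dice score, so "max" trains against the worst-performing channel and "min" against the best. A short usage sketch with the logits variant, which applies a sigmoid before computing the score (tensor contents and the choice of reduction are illustrative):

    import torch
    from torch_em.loss import DiceLossWithLogits

    logits = torch.randn(4, 2, 64, 64)  # raw network outputs, no sigmoid applied yet
    target = torch.randint(0, 2, (4, 2, 64, 64)).float()

    # "max" penalizes the worst channel; "sum" (the default) matches the previous behavior.
    loss = DiceLossWithLogits(reduce_channel="max")
    print(loss(logits, target))  # scalar tensor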

