diff --git a/src/eva/vision/losses/dice.py b/src/eva/vision/losses/dice.py index 8e6133b3..d5d31d17 100644 --- a/src/eva/vision/losses/dice.py +++ b/src/eva/vision/losses/dice.py @@ -45,9 +45,6 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index) targets = _to_one_hot(targets, num_classes=inputs.shape[1]) - if targets.ndim == 3: - targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1]) - return super().forward(inputs, targets)