Skip to content

Commit

Permalink
fix MIC-DKFZ#2181; please install latest (master) batchgeneratorsv2
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed May 28, 2024
1 parent 086ae96 commit 7141d4f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
10 changes: 8 additions & 2 deletions nnunetv2/training/loss/compound_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,20 @@ def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use
def forward(self, net_output: torch.Tensor, target: torch.Tensor):
if self.use_ignore_label:
# target is one hot encoded here. invert it so that it is True wherever we can compute the loss
mask = (1 - target[:, -1:]).bool()
if target.dtype == torch.bool:
mask = ~target[:, -1:]
else:
mask = (1 - target[:, -1:]).bool()
# remove ignore channel now that we have the mask
target_regions = torch.clone(target[:, :-1])
# why did we use clone in the past? Should have documented that...
# target_regions = torch.clone(target[:, :-1])
target_regions = target[:, :-1]
else:
target_regions = target
mask = None

dc_loss = self.dc(net_output, target_regions, loss_mask=mask)
target_regions = target_regions.float()
if mask is not None:
ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8)
else:
Expand Down
6 changes: 3 additions & 3 deletions nnunetv2/training/loss/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
y_onehot = torch.zeros(net_output.shape, device=net_output.device)
y_onehot = torch.zeros(net_output.shape, device=net_output.device, dtype=torch.bool)
y_onehot.scatter_(1, gt.long(), 1)

tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fp = net_output * (~y_onehot)
fn = (1 - net_output) * y_onehot
tn = (1 - net_output) * (1 - y_onehot)
tn = (1 - net_output) * (~y_onehot)

if mask is not None:
with torch.no_grad():
Expand Down
5 changes: 4 additions & 1 deletion nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,10 @@ def validation_step(self, batch: dict) -> dict:
# CAREFUL that you don't rely on target after this line!
target[target == self.label_manager.ignore_label] = 0
else:
mask = 1 - target[:, -1:]
if target.dtype == torch.bool:
mask = ~target[:, -1:]
else:
mask = 1 - target[:, -1:]
# CAREFUL that you don't rely on target after this line!
target = target[:, :-1]
else:
Expand Down

0 comments on commit 7141d4f

Please sign in to comment.