DualTaskLoss not working #72

Open

WaterKnight1998 opened this issue May 18, 2020 · 4 comments

~/Documents/pro1/seg/utils/loss_gscnn.py in forward(self, inputs, targets)
    344         losses['edge_loss'] = self.edge_weight * 20 * self.bce2d(edgein, edgemask)
    345         losses['att_loss'] = self.att_weight * self.edge_attention(segin, segmask, edgein)
--> 346         losses['dual_loss'] = self.dual_weight * self.dual_task(segin, segmask)
    347 
    348         return losses

~/anaconda3/envs/seg/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    556             result = self._slow_forward(*input, **kwargs)
    557         else:
--> 558             result = self.forward(*input, **kwargs)
    559         for hook in self._forward_hooks.values():
    560             hook_result = hook(self, input, result)

~/Documents/pro1/seg/utils/loss_gscnn.py in forward(self, input_logits, gts, ignore_pixel)
    240         input_logits = torch.where(ignore_mask.view(N, 1, H, W).expand(N, 19, H, W),
    241                                    torch.zeros(N,C,H,W).cuda(),
--> 242                                    input_logits)
    243         gt_semantic_masks = gts.detach()
    244         gt_semantic_masks = torch.where(ignore_mask, torch.zeros(N,H,W).long().cuda(), gt_semantic_masks)

RuntimeError: The size of tensor a (19) must match the size of tensor b (2) at non-singleton dimension 1
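For context, a minimal sketch of what this error means, assuming binary segmentation (C = 2) and a CPU-only toy setup: the hard-coded 19 in expand() corresponds to Cityscapes' 19 classes, so the expanded ignore mask no longer matches the 2-channel logits along dimension 1.

import torch

N, C, H, W = 1, 2, 4, 4  # binary segmentation: C = 2
input_logits = torch.randn(N, C, H, W)
ignore_mask = torch.zeros(N, H, W, dtype=torch.bool)

try:
    # Reproduces the failure: mask expanded to 19 channels vs C = 2 logits.
    torch.where(ignore_mask.view(N, 1, H, W).expand(N, 19, H, W),
                torch.zeros(N, C, H, W),
                input_logits)
except RuntimeError as e:
    print(e)

# Expanding with C instead of the hard-coded 19 broadcasts cleanly.
fixed = torch.where(ignore_mask.view(N, 1, H, W).expand(N, C, H, W),
                    torch.zeros(N, C, H, W),
                    input_logits)
print(fixed.shape)  # torch.Size([1, 2, 4, 4])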
WaterKnight1998 (Author) commented:

I updated the code, but the tensors passed to F.l1_loss still do not match:

N, C, H, W = input_logits.shape
th = 1e-8  # 1e-10
eps = 1e-10
ignore_mask = (gts == ignore_pixel).detach()
input_logits = torch.where(ignore_mask.view(N, 1, H, W).expand(N, C, H, W),
                           torch.zeros(N, C, H, W).cuda(),
                           input_logits)
gt_semantic_masks = gts.detach()
gt_semantic_masks = torch.where(ignore_mask, torch.zeros(N, H, W).long().cuda(), gt_semantic_masks)
# NOTE: 19 is still hard-coded here, so g_hat gets 19 channels while g gets C.
gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 19).detach()

g = _gumbel_softmax_sample(input_logits.view(N, C, -1), tau=0.5)
g = g.reshape((N, C, H, W))
g = compute_grad_mag(g, cuda=self._cuda)

g_hat = compute_grad_mag(gt_semantic_masks, cuda=self._cuda)

g = g.view(N, -1)
g_hat = g_hat.reshape(N, -1)
loss_ewise = F.l1_loss(g, g_hat, reduction='none')  # 'reduce=False' is deprecated; reduction='none' already disables reduction

p_plus_g_mask = (g >= th).detach().float()
loss_p_plus_g = torch.sum(loss_ewise * p_plus_g_mask) / (torch.sum(p_plus_g_mask) + eps)

p_plus_g_hat_mask = (g_hat >= th).detach().float()
loss_p_plus_g_hat = torch.sum(loss_ewise * p_plus_g_hat_mask) / (torch.sum(p_plus_g_hat_mask) + eps)

total_loss = 0.5 * loss_p_plus_g + 0.5 * loss_p_plus_g_hat
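To see why F.l1_loss complains here, a minimal sketch, assuming compute_grad_mag preserves the channel dimension: g is derived from the C-channel logits while g_hat is derived from the 19-channel one-hot mask, so their flattened sizes differ and the two tensors are not broadcastable.

import torch
import torch.nn.functional as F

N, C, H, W = 1, 2, 4, 4
g = torch.randn(N, C, H, W).view(N, -1)          # N x (2*H*W)  = 1 x 32
g_hat = torch.randn(N, 19, H, W).reshape(N, -1)  # N x (19*H*W) = 1 x 304

try:
    F.l1_loss(g, g_hat, reduction='none')
except RuntimeError as e:
    print(e)  # 1x32 and 1x304 cannot be broadcast together

# One-hot embedding the ground truth with C classes instead of 19 makes
# both tensors flatten to N x (C*H*W), and the element-wise loss works.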

WaterKnight1998 (Author) commented:

@ayinaaaaaa I am doing binary segmentation!

ShreyasHavaldar7 commented Jun 19, 2020:

Hey @WaterKnight1998,
I was facing the same issue in a similar situation. One change you should make is to replace the 19 with 2, since you are dealing with binary segmentation and therefore have 2 classes. If I assume correctly, the 19 corresponds to the 19 classes of the Cityscapes dataset. Replacing

gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 19).detach()

with

gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 2).detach()

works.

This should be enough to resolve the error.
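A hedged generalization of this fix: derive the class count from the logits instead of hard-coding 19 (Cityscapes) or 2 (binary), so the same loss code works for any dataset. The sketch below uses torch.nn.functional.one_hot as a stand-in for the repo's _one_hot_embedding helper; the permute moves classes into the channel dimension.

import torch
import torch.nn.functional as F

N, C, H, W = 1, 2, 4, 4
input_logits = torch.randn(N, C, H, W)
gts = torch.randint(0, C, (N, H, W))

num_classes = input_logits.shape[1]             # read C from the logits, no hard-coded 19
gt_semantic_masks = F.one_hot(gts, num_classes)  # N x H x W x C (long)
gt_semantic_masks = gt_semantic_masks.permute(0, 3, 1, 2).float().detach()  # N x C x H x W
print(gt_semantic_masks.shape)  # torch.Size([1, 2, 4, 4])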

reacher-l commented:

In gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 2).detach(), the 2 should be your number of classes; using my class count here, it passes for me.
