-
Notifications
You must be signed in to change notification settings - Fork 5
/
mask2former_head.py
163 lines (133 loc) · 5.97 KB
/
mask2former_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
try:
from mmdet.models.dense_heads import \
Mask2FormerHead as MMDET_Mask2FormerHead
except ModuleNotFoundError:
MMDET_Mask2FormerHead = BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmdepth.registry import MODELS
from mmdepth.structures.seg_data_sample import SegDataSample
from mmdepth.utils import ConfigType, SampleList
@MODELS.register_module()
class Mask2FormerHead(MMDET_Mask2FormerHead):
    """Implements the Mask2Former head.

    See `Mask2Former: Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/abs/2112.01527>`_ for details.

    Args:
        num_classes (int): Number of classes. Required (no default).
        align_corners (bool): align_corners argument of F.interpolate,
            used when upsampling the predicted masks in :meth:`predict`.
            Default: False.
        ignore_index (int): The label index to be ignored. Default: 255.
        **kwargs: Forwarded unchanged to the parent head's constructor.
            Must contain ``feat_channels``, which is also read here to
            size the classification layer.
    """

    def __init__(self,
                 num_classes,
                 align_corners=False,
                 ignore_index=255,
                 **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.align_corners = align_corners
        self.out_channels = num_classes
        self.ignore_index = ignore_index

        # Replace the classifier after the parent constructor has run so
        # that it has ``num_classes + 1`` outputs.
        # NOTE(review): presumably the parent sizes its own ``cls_embed``
        # from different class-count arguments - confirm against mmdet.
        feat_channels = kwargs['feat_channels']
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)

    def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
        """Convert semantic-segmentation samples to the instance-style
        ground truth expected by ``MMDET_Mask2FormerHead``.

        For every sample, each label value present in ``gt_sem_seg``
        (except ``self.ignore_index``) becomes one "instance" whose mask
        marks the pixels carrying that label.

        Args:
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. Each must provide ``metainfo`` and
                ``gt_sem_seg.data``.

        Returns:
            tuple[list, list]: A tuple containing two lists.

                - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                  gt_instance. Each holds ``labels``, the unique ground
                  truth label ids of the image with shape (num_gt, ), and
                  ``masks``, one integer (long) mask per label id.
                - batch_img_metas (list[dict]): List of image meta
                  information.
        """
        batch_img_metas = []
        batch_gt_instances = []

        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            gt_sem_seg = data_sample.gt_sem_seg.data

            classes = torch.unique(
                gt_sem_seg,
                sorted=False,
                return_inverse=False,
                return_counts=False)

            # remove ignored region
            gt_labels = classes[classes != self.ignore_index]

            # One boolean mask per remaining label id.
            masks = [gt_sem_seg == class_id for class_id in gt_labels]

            if not masks:
                # Image contains only ignored pixels: emit an empty
                # (0, h, w) mask tensor on the same device as the labels.
                gt_masks = torch.zeros(
                    (0, gt_sem_seg.shape[-2],
                     gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
            else:
                # squeeze(1) drops the size-1 dimension that follows the
                # stacked label dimension.
                # NOTE(review): assumes gt_sem_seg is (1, h, w) - confirm.
                gt_masks = torch.stack(masks).squeeze(1).long()

            instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
            batch_gt_instances.append(instance_data)

        return batch_gt_instances, batch_img_metas

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Perform forward propagation and loss calculation of the decoder
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.
            train_cfg (ConfigType): Training config. Not read by this
                method.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # batch SegDataSample to InstanceDataSample
        batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
            batch_data_samples)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tuple[Tensor]:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_img_metas (List[dict]): Image meta information, one
                mapping per image. The first entry must provide
                ``pad_shape`` or ``img_shape``; it determines the output
                size for the whole batch.
            test_cfg (ConfigType): Test config. Not read by this method.

        Returns:
            Tensor: Segmentation scores laid out as
            (batch, class, height, width).
        """
        # The forward call takes data samples, so wrap each metainfo.
        batch_data_samples = [
            SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas
        ]

        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        # Only the last entry of each output sequence is used.
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # Prefer the padded shape when it is available.
        if 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape']
        else:
            size = batch_img_metas[0]['img_shape']

        # upsample mask, honouring the configured ``align_corners``
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=size,
            mode='bilinear',
            align_corners=self.align_corners)

        # Drop the last of the ``num_classes + 1`` class scores
        # (presumably the "no object" class - confirm against mmdet).
        cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
        mask_pred = mask_pred_results.sigmoid()

        # Sum over queries: per-class score times per-query mask.
        seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred)
        return seg_logits