Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] remove expand from mmdet rewriter #371

Merged
merged 1 commit into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def select_nms_index(scores: torch.Tensor,
_, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)
topk_batch_inds = torch.arange(
batch_size, dtype=topk_inds.dtype,
device=topk_inds.device).view(-1, 1).expand_as(topk_inds)
device=topk_inds.device).view(-1, 1)
batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
batched_labels = batched_labels[topk_batch_inds, topk_inds, ...]

Expand Down Expand Up @@ -96,8 +96,7 @@ def _multiclass_nms(boxes: Tensor,
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_top_k)
batch_inds = torch.arange(batch_size).view(
-1, 1).expand_as(topk_inds).long()
batch_inds = torch.arange(batch_size).view(-1, 1).long()
boxes = boxes[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]

Expand Down Expand Up @@ -298,8 +297,7 @@ def multiclass_nms__torchscript(ctx,
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_top_k)
batch_inds = torch.arange(batch_size).view(
-1, 1).expand_as(topk_inds).long()
batch_inds = torch.arange(batch_size).view(-1, 1).long()
boxes = boxes[batch_inds, topk_inds, ...]
scores = scores[batch_inds, topk_inds, :]
num_boxes = scores.shape[1]
Expand Down
15 changes: 4 additions & 11 deletions mmdeploy/codebase/mmdet/deploy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,11 @@ def pad_with_value(x: Tensor,
Returns:
Tensor: Padded tensor.
"""
num_dims = len(x.shape)
pad_slice = (slice(None, None, None), ) * num_dims
pad_slice = pad_slice[:pad_dim] + (slice(0, 1,
1), ) + pad_slice[pad_dim + 1:]
repeat_size = [1] * num_dims
repeat_size[pad_dim] = pad_size

x_pad = x.__getitem__(pad_slice)
x_shape = list(x.shape)
pad_shape = x_shape[:pad_dim] + [pad_size] + x_shape[pad_dim + 1:]
x_pad = x.new_zeros(pad_shape)
if pad_value is not None:
x_pad = x_pad * 0 + pad_value

x_pad = x_pad.repeat(*repeat_size)
x_pad = x_pad + pad_value
x = torch.cat([x, x_pad], dim=pad_dim)
return x

Expand Down
8 changes: 4 additions & 4 deletions mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def base_dense_head__get_bbox(ctx,
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device)
mlvl_priors = [priors.unsqueeze(0) for priors in mlvl_priors]

mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
Expand Down Expand Up @@ -108,7 +109,6 @@ def base_dense_head__get_bbox(ctx,
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
if not is_dynamic_flag:
priors = priors.data
priors = priors.expand(batch_size, -1, priors.size(-1))
if pre_topk > 0:
priors = pad_with_value_if_necessary(priors, 1, pre_topk)
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
Expand All @@ -128,9 +128,9 @@ def base_dense_head__get_bbox(ctx,
max_scores, _ = nms_pre_score[..., :-1].max(-1)
_, topk_inds = max_scores.topk(pre_topk)
batch_inds = torch.arange(
batch_size,
device=bbox_pred.device).view(-1, 1).expand_as(topk_inds)
priors = priors[batch_inds, topk_inds, :]
batch_size, device=bbox_pred.device).unsqueeze(-1)
prior_inds = batch_inds.new_zeros((1, 1))
priors = priors[prior_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
if with_score_factors:
Expand Down
8 changes: 4 additions & 4 deletions mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def gfl_head__get_bbox(ctx,
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device)
mlvl_priors = [priors.unsqueeze(0) for priors in mlvl_priors]

mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
Expand Down Expand Up @@ -110,7 +111,6 @@ def gfl_head__get_bbox(ctx,
bbox_pred.permute(0, 2, 3, 1)) * stride[0]
if not is_dynamic_flag:
priors = priors.data
priors = priors.expand(batch_size, -1, priors.size(-1))
if pre_topk > 0:
if with_score_factors:
nms_pre_score = nms_pre_score * score_factors
Expand All @@ -130,9 +130,9 @@ def gfl_head__get_bbox(ctx,
max_scores, _ = nms_pre_score[..., :-1].max(-1)
_, topk_inds = max_scores.topk(pre_topk)
batch_inds = torch.arange(
batch_size,
device=bbox_pred.device).view(-1, 1).expand_as(topk_inds)
priors = priors[batch_inds, topk_inds, :]
batch_size, device=bbox_pred.device).unsqueeze(-1)
prior_inds = batch_inds.new_zeros((1, 1))
priors = priors[prior_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
if with_score_factors:
Expand Down
8 changes: 4 additions & 4 deletions mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def rpn_head__get_bboxes(ctx,
if not is_dynamic_flag:
anchors = anchors.data

anchors = anchors.expand_as(bbox_pred)
anchors = anchors.unsqueeze(0)

# topk in tensorrt does not support shape<k
# concate zero to enable topk,
Expand All @@ -104,9 +104,9 @@ def rpn_head__get_bboxes(ctx,

if pre_topk > 0:
_, topk_inds = scores.squeeze(2).topk(pre_topk)
batch_inds = torch.arange(
batch_size, device=device).view(-1, 1).expand_as(topk_inds)
anchors = anchors[batch_inds, topk_inds, :]
batch_inds = torch.arange(batch_size, device=device).unsqueeze(-1)
prior_inds = topk_inds.new_zeros((1, 1))
anchors = anchors[prior_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
mlvl_valid_bboxes.append(bbox_pred)
Expand Down
19 changes: 9 additions & 10 deletions mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def yolov3_head__get_bboxes(ctx,
# use static anchor if input shape is static
if not is_dynamic_flag:
multi_lvl_anchor = multi_lvl_anchor.data
multi_lvl_anchor = multi_lvl_anchor.unsqueeze(0).expand_as(
pred_map_boxes)
multi_lvl_anchor = multi_lvl_anchor.unsqueeze(0)
bbox_pred = self.bbox_coder.decode(multi_lvl_anchor, pred_map_boxes,
stride)
# conf and cls
Expand All @@ -100,8 +99,7 @@ def yolov3_head__get_bboxes(ctx,
if pre_topk > 0:
_, topk_inds = conf_pred.topk(pre_topk)
batch_inds = torch.arange(
batch_size, device=device).view(-1,
1).expand_as(topk_inds).long()
batch_size, device=device).unsqueeze(-1).long()
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
transformed_inds = (bbox_pred.shape[1] * batch_inds + topk_inds)
bbox_pred = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape(
Expand Down Expand Up @@ -129,14 +127,15 @@ def yolov3_head__get_bboxes(ctx,

# follow original pipeline of YOLOv3
if confidence_threshold > 0:
mask = (batch_mlvl_conf_scores >= confidence_threshold).float()
batch_mlvl_conf_scores *= mask
mask = batch_mlvl_conf_scores >= confidence_threshold
batch_mlvl_conf_scores = batch_mlvl_conf_scores.where(
mask, batch_mlvl_conf_scores.new_zeros(1))
if score_threshold > 0:
mask = (batch_mlvl_scores > score_threshold).float()
batch_mlvl_scores *= mask
mask = batch_mlvl_scores > score_threshold
batch_mlvl_scores = batch_mlvl_scores.where(
mask, batch_mlvl_scores.new_zeros(1))

batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2).expand_as(
batch_mlvl_scores)
batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2)
batch_mlvl_scores = batch_mlvl_scores * batch_mlvl_conf_scores

if with_nms:
Expand Down
3 changes: 1 addition & 2 deletions mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ def bbox_head__get_bboxes(ctx,
# only keep boxes with the max scores
max_inds = scores.reshape(-1, self.num_classes).argmax(1, keepdim=True)
bboxes = bboxes.reshape(-1, self.num_classes, 4)
dim0_inds = torch.arange(
bboxes.shape[0], device=device).view(-1, 1).expand_as(max_inds)
dim0_inds = torch.arange(bboxes.shape[0], device=device).unsqueeze(-1)
bboxes = bboxes[dim0_inds, max_inds].reshape(batch_size, -1, 4)

# get nms params
Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
return det_bboxes, det_labels
else:
batch_index = torch.arange(
det_bboxes.size(0), device=det_bboxes.device).float().view(
-1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
det_bboxes.size(0),
device=det_bboxes.device).float().view(-1, 1, 1)
rois = det_bboxes[..., :4]
mask_rois = torch.cat([batch_index, rois], dim=-1)
mask_rois = mask_rois.view(-1, 5)
Expand Down