From cc3378d8b44ba06dae5383c1522960bca27e38af Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 15 Apr 2022 23:09:16 +0800 Subject: [PATCH] remove expand from mmdet rewriter --- .../mmdet/core/post_processing/bbox_nms.py | 8 +++----- mmdeploy/codebase/mmdet/deploy/utils.py | 15 ++++----------- .../models/dense_heads/base_dense_head.py | 8 ++++---- .../mmdet/models/dense_heads/gfl_head.py | 8 ++++---- .../mmdet/models/dense_heads/rpn_head.py | 8 ++++---- .../mmdet/models/dense_heads/yolo_head.py | 19 +++++++++---------- .../mmdet/models/roi_heads/bbox_head.py | 3 +-- .../models/roi_heads/cascade_roi_head.py | 4 ++-- 8 files changed, 31 insertions(+), 42 deletions(-) diff --git a/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py b/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py index ee7a1403d7..50d6c7caea 100644 --- a/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py +++ b/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py @@ -66,7 +66,7 @@ def select_nms_index(scores: torch.Tensor, _, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True) topk_batch_inds = torch.arange( batch_size, dtype=topk_inds.dtype, - device=topk_inds.device).view(-1, 1).expand_as(topk_inds) + device=topk_inds.device).view(-1, 1) batched_dets = batched_dets[topk_batch_inds, topk_inds, ...] batched_labels = batched_labels[topk_batch_inds, topk_inds, ...] @@ -96,8 +96,7 @@ def _multiclass_nms(boxes: Tensor, if pre_top_k > 0: max_scores, _ = scores.max(-1) _, topk_inds = max_scores.topk(pre_top_k) - batch_inds = torch.arange(batch_size).view( - -1, 1).expand_as(topk_inds).long() + batch_inds = torch.arange(batch_size).view(-1, 1).long() boxes = boxes[batch_inds, topk_inds, :] scores = scores[batch_inds, topk_inds, :] @@ -298,8 +297,7 @@ def multiclass_nms__torchscript(ctx, if pre_top_k > 0: max_scores, _ = scores.max(-1) _, topk_inds = max_scores.topk(pre_top_k) - batch_inds = torch.arange(batch_size).view( - -1, 1).expand_as(topk_inds).long() + batch_inds = torch.arange(batch_size).view(-1, 1).long() boxes = boxes[batch_inds, topk_inds, ...] scores = scores[batch_inds, topk_inds, :] num_boxes = scores.shape[1] diff --git a/mmdeploy/codebase/mmdet/deploy/utils.py b/mmdeploy/codebase/mmdet/deploy/utils.py index 5fd5b7ab78..690934ad7e 100644 --- a/mmdeploy/codebase/mmdet/deploy/utils.py +++ b/mmdeploy/codebase/mmdet/deploy/utils.py @@ -113,18 +113,11 @@ def pad_with_value(x: Tensor, Returns: Tensor: Padded tensor. """ - num_dims = len(x.shape) - pad_slice = (slice(None, None, None), ) * num_dims - pad_slice = pad_slice[:pad_dim] + (slice(0, 1, - 1), ) + pad_slice[pad_dim + 1:] - repeat_size = [1] * num_dims - repeat_size[pad_dim] = pad_size - - x_pad = x.__getitem__(pad_slice) + x_shape = list(x.shape) + pad_shape = x_shape[:pad_dim] + [pad_size] + x_shape[pad_dim + 1:] + x_pad = x.new_zeros(pad_shape) if pad_value is not None: - x_pad = x_pad * 0 + pad_value - - x_pad = x_pad.repeat(*repeat_size) + x_pad = x_pad + pad_value x = torch.cat([x, x_pad], dim=pad_dim) return x diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py index cd224a4548..f5806f67c2 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py @@ -66,6 +66,7 @@ def base_dense_head__get_bbox(ctx, featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] mlvl_priors = self.prior_generator.grid_priors( featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device) + mlvl_priors = [priors.unsqueeze(0) for priors in mlvl_priors] mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] @@ -108,7 +109,6 @@ def base_dense_head__get_bbox(ctx, bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) if not is_dynamic_flag: priors = priors.data - priors = priors.expand(batch_size, -1, priors.size(-1)) if pre_topk > 0: priors = pad_with_value_if_necessary(priors, 1, pre_topk) bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk) @@ -128,9 +128,9 @@ def base_dense_head__get_bbox(ctx, max_scores, _ = nms_pre_score[..., :-1].max(-1) _, topk_inds = max_scores.topk(pre_topk) batch_inds = torch.arange( - batch_size, - device=bbox_pred.device).view(-1, 1).expand_as(topk_inds) - priors = priors[batch_inds, topk_inds, :] + batch_size, device=bbox_pred.device).unsqueeze(-1) + prior_inds = batch_inds.new_zeros((1, 1)) + priors = priors[prior_inds, topk_inds, :] bbox_pred = bbox_pred[batch_inds, topk_inds, :] scores = scores[batch_inds, topk_inds, :] if with_score_factors: diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py index 8dba8b5666..6827c75b7f 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py @@ -63,6 +63,7 @@ def gfl_head__get_bbox(ctx, featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] mlvl_priors = self.prior_generator.grid_priors( featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device) + mlvl_priors = [priors.unsqueeze(0) for priors in mlvl_priors] mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] @@ -110,7 +111,6 @@ def gfl_head__get_bbox(ctx, bbox_pred.permute(0, 2, 3, 1)) * stride[0] if not is_dynamic_flag: priors = priors.data - priors = priors.expand(batch_size, -1, priors.size(-1)) if pre_topk > 0: if with_score_factors: nms_pre_score = nms_pre_score * score_factors @@ -130,9 +130,9 @@ def gfl_head__get_bbox(ctx, max_scores, _ = nms_pre_score[..., :-1].max(-1) _, topk_inds = max_scores.topk(pre_topk) batch_inds = torch.arange( - batch_size, - device=bbox_pred.device).view(-1, 1).expand_as(topk_inds) - priors = priors[batch_inds, topk_inds, :] + batch_size, device=bbox_pred.device).unsqueeze(-1) + prior_inds = batch_inds.new_zeros((1, 1)) + priors = priors[prior_inds, topk_inds, :] bbox_pred = bbox_pred[batch_inds, topk_inds, :] scores = scores[batch_inds, topk_inds, :] if with_score_factors: diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py index 5523c44b47..fdd33785b4 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py @@ -94,7 +94,7 @@ def rpn_head__get_bboxes(ctx, if not is_dynamic_flag: anchors = anchors.data - anchors = anchors.expand_as(bbox_pred) + anchors = anchors.unsqueeze(0) # topk in tensorrt does not support shape 0: _, topk_inds = scores.squeeze(2).topk(pre_topk) - batch_inds = torch.arange( - batch_size, device=device).view(-1, 1).expand_as(topk_inds) - anchors = anchors[batch_inds, topk_inds, :] + batch_inds = torch.arange(batch_size, device=device).unsqueeze(-1) + prior_inds = topk_inds.new_zeros((1, 1)) + anchors = anchors[prior_inds, topk_inds, :] bbox_pred = bbox_pred[batch_inds, topk_inds, :] scores = scores[batch_inds, topk_inds, :] mlvl_valid_bboxes.append(bbox_pred) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py index cb57cddfc3..11acb61f61 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py @@ -83,8 +83,7 @@ def yolov3_head__get_bboxes(ctx, # use static anchor if input shape is static if not is_dynamic_flag: multi_lvl_anchor = multi_lvl_anchor.data - multi_lvl_anchor = multi_lvl_anchor.unsqueeze(0).expand_as( - pred_map_boxes) + multi_lvl_anchor = multi_lvl_anchor.unsqueeze(0) bbox_pred = self.bbox_coder.decode(multi_lvl_anchor, pred_map_boxes, stride) # conf and cls @@ -100,8 +99,7 @@ def yolov3_head__get_bboxes(ctx, if pre_topk > 0: _, topk_inds = conf_pred.topk(pre_topk) batch_inds = torch.arange( - batch_size, device=device).view(-1, - 1).expand_as(topk_inds).long() + batch_size, device=device).unsqueeze(-1).long() # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 transformed_inds = (bbox_pred.shape[1] * batch_inds + topk_inds) bbox_pred = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape( @@ -129,14 +127,15 @@ def yolov3_head__get_bboxes(ctx, # follow original pipeline of YOLOv3 if confidence_threshold > 0: - mask = (batch_mlvl_conf_scores >= confidence_threshold).float() - batch_mlvl_conf_scores *= mask + mask = batch_mlvl_conf_scores >= confidence_threshold + batch_mlvl_conf_scores = batch_mlvl_conf_scores.where( + mask, batch_mlvl_conf_scores.new_zeros(1)) if score_threshold > 0: - mask = (batch_mlvl_scores > score_threshold).float() - batch_mlvl_scores *= mask + mask = batch_mlvl_scores > score_threshold + batch_mlvl_scores = batch_mlvl_scores.where( + mask, batch_mlvl_scores.new_zeros(1)) - batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2).expand_as( - batch_mlvl_scores) + batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2) batch_mlvl_scores = batch_mlvl_scores * batch_mlvl_conf_scores if with_nms: diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py index 65473bc70e..9f596d340c 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py @@ -98,8 +98,7 @@ def bbox_head__get_bboxes(ctx, # only keep boxes with the max scores max_inds = scores.reshape(-1, self.num_classes).argmax(1, keepdim=True) bboxes = bboxes.reshape(-1, self.num_classes, 4) - dim0_inds = torch.arange( - bboxes.shape[0], device=device).view(-1, 1).expand_as(max_inds) + dim0_inds = torch.arange(bboxes.shape[0], device=device).unsqueeze(-1) bboxes = bboxes[dim0_inds, max_inds].reshape(batch_size, -1, 4) # get nms params diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py index bf3c24d5fb..107439ac6e 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py @@ -84,8 +84,8 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas, return det_bboxes, det_labels else: batch_index = torch.arange( - det_bboxes.size(0), device=det_bboxes.device).float().view( - -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1) + det_bboxes.size(0), + device=det_bboxes.device).float().view(-1, 1, 1) rois = det_bboxes[..., :4] mask_rois = torch.cat([batch_index, rois], dim=-1) mask_rois = mask_rois.view(-1, 5)