diff --git a/export.py b/export.py index 0fba54142d..cf918aa42b 100644 --- a/export.py +++ b/export.py @@ -146,7 +146,7 @@ if opt.grid: if opt.end2end: print('\nStarting export end2end onnx model for %s...' % 'TensorRT' if opt.max_wh is None else 'onnxruntime') - model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device) + model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device,len(labels)) if opt.end2end and opt.max_wh is None: output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes'] shapes = [opt.batch_size, 1, opt.batch_size, opt.topk_all, 4, diff --git a/models/experimental.py b/models/experimental.py index 3fa5c12e31..735d7aa0eb 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -158,7 +158,7 @@ def symbolic(g, class ONNX_ORT(nn.Module): '''onnx module with ONNX-Runtime NMS operation.''' - def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None): + def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80): super().__init__() self.device = device if device else torch.device("cpu") self.max_obj = torch.tensor([max_obj]).to(device) @@ -168,12 +168,17 @@ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, de self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]], dtype=torch.float32, device=self.device) + self.n_classes=n_classes def forward(self, x): boxes = x[:, :, :4] conf = x[:, :, 4:5] scores = x[:, :, 5:] - scores *= conf + if self.n_classes == 1: + scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5, + # so there is no need to multiply. 
+ else: + scores *= conf # conf = obj_conf * cls_conf boxes @= self.convert_matrix max_score, category_id = scores.max(2, keepdim=True) dis = category_id.float() * self.max_wh @@ -189,7 +194,7 @@ def forward(self, x): class ONNX_TRT(nn.Module): '''onnx module with TensorRT NMS operation.''' - def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None): + def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80): super().__init__() assert max_wh is None self.device = device if device else torch.device('cpu') @@ -200,12 +205,17 @@ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,d self.plugin_version = '1' self.score_activation = 0 self.score_threshold = score_thres + self.n_classes=n_classes def forward(self, x): boxes = x[:, :, :4] conf = x[:, :, 4:5] scores = x[:, :, 5:] - scores *= conf + if self.n_classes == 1: + scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5, + # so there is no need to multiply. 
+ else: + scores *= conf # conf = obj_conf * cls_conf num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding, self.iou_threshold, self.max_obj, self.plugin_version, self.score_activation, @@ -215,14 +225,14 @@ def forward(self, x): class End2End(nn.Module): '''export onnx or tensorrt model with NMS operation.''' - def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None): + def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80): super().__init__() device = device if device else torch.device('cpu') assert isinstance(max_wh,(int)) or max_wh is None self.model = model.to(device) self.model.model[-1].end2end = True self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT - self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device) + self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes) self.end2end.eval() def forward(self, x):