From 1c717aa80085193fb18d2d110ef9eb40d40b82e8 Mon Sep 17 00:00:00 2001 From: "mikel.brostrom" Date: Sun, 28 May 2023 09:03:33 +0200 Subject: [PATCH 1/2] add @torch.no_grad() to all ReID related inference methods --- boxmot/botsort/bot_sort.py | 3 +++ boxmot/deepocsort/ocsort.py | 3 ++- boxmot/strongsort/strong_sort.py | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/boxmot/botsort/bot_sort.py b/boxmot/botsort/bot_sort.py index 3d92a3edee..a7727aa33c 100644 --- a/boxmot/botsort/bot_sort.py +++ b/boxmot/botsort/bot_sort.py @@ -7,6 +7,7 @@ from .gmc import GMC from .basetrack import BaseTrack, TrackState from .kalman_filter import KalmanFilter +import torch # from fast_reid.fast_reid_interfece import FastReIDInterface @@ -263,6 +264,7 @@ def __init__(self, self.proximity_thresh = proximity_thresh self.appearance_thresh = appearance_thresh self.match_thresh = match_thresh + print(device) self.model = ReIDDetectMultiBackend(weights=model_weights, device=device, fp16=fp16) @@ -482,6 +484,7 @@ def _xywh_to_xyxy(self, bbox_xywh): y2 = min(int(y + h / 2), self.height - 1) return x1, y1, x2, y2 + @torch.no_grad() def _get_features(self, bbox_xywh, ori_img): im_crops = [] for box in bbox_xywh: diff --git a/boxmot/deepocsort/ocsort.py b/boxmot/deepocsort/ocsort.py index f39bf9d824..5c5afe849f 100644 --- a/boxmot/deepocsort/ocsort.py +++ b/boxmot/deepocsort/ocsort.py @@ -3,7 +3,7 @@ """ from __future__ import print_function - +import torch import numpy as np from .association import * from .cmc import CMCComputer @@ -528,6 +528,7 @@ def _xywh_to_xyxy(self, bbox_xywh): y2 = min(int(y + h / 2), self.height - 1) return x1, y1, x2, y2 + @torch.no_grad() def _get_features(self, bbox_xyxy, ori_img): im_crops = [] for box in bbox_xyxy: diff --git a/boxmot/strongsort/strong_sort.py b/boxmot/strongsort/strong_sort.py index b578903c15..a1b5c49f7a 100644 --- a/boxmot/strongsort/strong_sort.py +++ b/boxmot/strongsort/strong_sort.py @@ -128,6 +128,7 @@ def _xyxy_to_tlwh(self, bbox_xyxy): h = int(y2 - y1) return t, l, w, h + @torch.no_grad() def _get_features(self, bbox_xywh, ori_img): im_crops = [] for box in bbox_xywh: From 30f390a003070c770fc5c11d86d7c171ef2eaad3 Mon Sep 17 00:00:00 2001 From: "mikel.brostrom" Date: Sun, 28 May 2023 09:38:03 +0200 Subject: [PATCH 2/2] delete dubug print --- boxmot/botsort/bot_sort.py | 1 - 1 file changed, 1 deletion(-) diff --git a/boxmot/botsort/bot_sort.py b/boxmot/botsort/bot_sort.py index a7727aa33c..14df89dcdc 100644 --- a/boxmot/botsort/bot_sort.py +++ b/boxmot/botsort/bot_sort.py @@ -264,7 +264,6 @@ def __init__(self, self.proximity_thresh = proximity_thresh self.appearance_thresh = appearance_thresh self.match_thresh = match_thresh - print(device) self.model = ReIDDetectMultiBackend(weights=model_weights, device=device, fp16=fp16)