Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mmdet TensorRT support #1042

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ jobs:
pip install mmengine==0.7.3
pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.13.0/index.html
pip install mmdet==3.0.0

- name: Install MMDeploy(1.3.1)
run: >
pip install openmim
mim install mmdeploy==1.3.1
fcakyon marked this conversation as resolved.
Show resolved Hide resolved

mim install mmdeploy-runtime==1.3.1

- name: Install YOLOv5(7.0.13)
run: >
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/package_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ jobs:
pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.13.0/index.html
pip install mmdet==3.0.0

- name: Install MMDeploy(1.3.1)
run: >
pip install openmim
mim install mmdeploy==1.3.1
fcakyon marked this conversation as resolved.
Show resolved Hide resolved
mim install mmdeploy-runtime==1.3.1

- name: Install YOLOv5(7.0.13)
run: >
pip install yolov5==7.0.13
Expand Down
6 changes: 5 additions & 1 deletion sahi/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@ def __init__(
category_remapping: Optional[Dict] = None,
load_at_init: bool = True,
image_size: int = None,
deploy_config_path: Optional[str] = None,
):
"""
Init object detection/instance segmentation model.
Args:
model_path: str
Path for the instance segmentation model weight
Path for the instance segmentation model weight, .engine file if it's tensorrt
config_path: str
Path for the mmdetection instance segmentation model config file
deploy_config_path: str
Path for the mmdetection detection/instance segmentation deployment config file
device: str
Torch device, "cpu" or "cuda"
mask_threshold: float
Expand All @@ -47,6 +50,7 @@ def __init__(
"""
self.model_path = model_path
self.config_path = config_path
self.deploy_config_path = deploy_config_path
self.model = None
self.device = device
self.mask_threshold = mask_threshold
Expand Down
76 changes: 71 additions & 5 deletions sahi/models/mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,68 @@


try:
check_requirements(["torch", "mmdet", "mmcv", "mmengine"])
check_requirements(["torch", "mmdet", "mmcv", "mmengine", "mmdeploy"])

import torch
from mmdeploy.apis.utils import build_task_processor
from mmdeploy.utils import get_input_shape, load_config
from mmdet.apis.det_inferencer import DetInferencer
from mmdet.utils import ConfigType
from mmengine.dataset import Compose
from mmengine.infer.infer import ModelType

class DetTrtInferencerWrapper:
def __init__(self, deploy_cfg: str, model_cfg: str, engine_file: str, device: Optional[str] = None) -> None:
"""
Emulate DetInferencer(images) for TensorRT model
Args:
deploy_cfg: str
Deployment cfg file, for example, detection_tensorrt-fp16_static-640x640.py.
model_cfg: str
Model cfg file, for example, rtmdet_l_8xb32-300e_coco.py
engine_file: str
Serialized TensorRT file, i.e end2end.engine
"""
deploy_cfg, model_cfg = load_config(
deploy_cfg,
model_cfg,
)
self.cfg = model_cfg
self.task_processor = build_task_processor(model_cfg, deploy_cfg, device)
self.model = self.task_processor.build_backend_model(
[engine_file], self.task_processor.update_data_preprocessor
)
self.input_shape = get_input_shape(deploy_cfg)
self.output_names = set(self.model.output_names)

def __call__(self, images: List[np.ndarray], batch_size: int = 1) -> dict:
"""
Emulate DetInferencerWrapper(images) for TensorRT model
Args:
images: list of np.ndarray
A list of numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
batch_size: int
Inference batch size. Defaults to 1.
"""

def _tensor_to_list(tensor):
return tensor.tolist() if tensor.numel() > 0 else []

results_dict = {"predictions": [], "visualization": []}
for image in images:
model_inputs, _ = self.task_processor.create_input(image, self.input_shape)
with torch.no_grad():
results = self.model.test_step(model_inputs)[0]
predictions = [
{
"scores": _tensor_to_list(results.pred_instances.scores.cpu()),
"labels": _tensor_to_list(results.pred_instances.labels.cpu()),
"bboxes": _tensor_to_list(results.pred_instances.bboxes.cpu()),
}
]
results_dict["predictions"].extend(predictions)
return results_dict

class DetInferencerWrapper(DetInferencer):
def __init__(
self,
Expand Down Expand Up @@ -94,6 +149,7 @@ def __init__(
model_path: Optional[str] = None,
model: Optional[Any] = None,
config_path: Optional[str] = None,
deploy_config_path: Optional[str] = None,
device: Optional[str] = None,
mask_threshold: float = 0.5,
confidence_threshold: float = 0.3,
Expand All @@ -108,6 +164,8 @@ def __init__(

self.scope = scope
self.image_size = image_size
# Check if tensorrt deploy cfg is defined
self.trt = deploy_config_path is not None

super().__init__(
model_path,
Expand All @@ -120,6 +178,7 @@ def __init__(
category_remapping,
load_at_init,
image_size,
deploy_config_path,
)

def check_dependencies(self):
Expand All @@ -129,11 +188,15 @@ def load_model(self):
"""
Detection model is initialized and set to self.model.
"""

# create model
model = DetInferencerWrapper(
self.config_path, self.model_path, device=self.device, scope=self.scope, image_size=self.image_size
)
if self.trt:
model = DetTrtInferencerWrapper(
self.deploy_config_path, self.config_path, self.model_path, device=self.device.type
)
else:
model = DetInferencerWrapper(
self.config_path, self.model_path, device=self.device, scope=self.scope, image_size=self.image_size
)

self.set_model(model)

Expand All @@ -148,6 +211,9 @@ def set_model(self, model: Any):
# set self.model
self.model = model

if self.trt and not self.category_mapping:
raise ValueError("TensorRT model needs category_mapping defined and passed to the constructor")

# set category_mapping
if not self.category_mapping:
category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
Expand Down
Loading