Skip to content

Commit

Permalink
Keep consistency of transform (#102)
Browse files Browse the repository at this point in the history
* Keep consistency of transform

* Fix jit scripting

* Update notebooks

* Fix unit-test of torchscript
  • Loading branch information
zhiqwang authored Apr 30, 2021
1 parent f4bb74c commit 0dfc40f
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 84 deletions.
30 changes: 16 additions & 14 deletions notebooks/export-onnx-inference-onnxruntime.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@
"metadata": {},
"outputs": [],
"source": [
"img_one = get_image_from_url(\"https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/test/assets/bus.jpg\")\n",
"# img_one = get_image_from_url(\"https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/test/assets/bus.jpg\")\n",
"img_one = cv2.imread('../test/assets/bus.jpg')\n",
"img_one = read_image_to_tensor(img_one, is_half=False)\n",
"img_one = img_one.to(device)\n",
"\n",
"img_two = get_image_from_url(\"https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/test/assets/bus.jpg\")\n",
"# img_two = get_image_from_url(\"https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/test/assets/zidane.jpg\")\n",
"img_two = cv2.imread('../test/assets/zidane.jpg')\n",
"img_two = read_image_to_tensor(img_two, is_half=False)\n",
"img_two = img_two.to(device)\n",
"\n",
Expand Down Expand Up @@ -100,8 +102,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.72 s, sys: 132 ms, total: 3.86 s\n",
"Wall time: 109 ms\n"
"CPU times: user 4.04 s, sys: 68 ms, total: 4.11 s\n",
"Wall time: 117 ms\n"
]
}
],
Expand All @@ -119,10 +121,10 @@
{
"data": {
"text/plain": [
"tensor([[ 52.1603, 384.9539, 235.4333, 899.1226],\n",
" [223.7285, 407.0463, 346.9296, 862.0854],\n",
" [ 8.5867, 227.5113, 805.2753, 765.3226],\n",
" [675.7438, 394.1103, 811.3925, 869.5128]])"
"tensor([[ 53.9381, 389.6684, 238.1329, 898.0309],\n",
" [223.5657, 409.1237, 344.0657, 861.4382],\n",
" [ 19.7466, 226.0309, 795.1536, 762.1879],\n",
" [675.1501, 390.4896, 813.6778, 874.6315]])"
]
},
"execution_count": 7,
Expand All @@ -142,7 +144,7 @@
{
"data": {
"text/plain": [
"tensor([0.8993, 0.8671, 0.8034, 0.8005])"
"tensor([0.8762, 0.8721, 0.8337, 0.8301])"
]
},
"execution_count": 8,
Expand Down Expand Up @@ -230,11 +232,11 @@
" anchor_grid = torch.as_tensor(anchor_grid, dtype=dtype, device=device)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:77: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" shifts = shifts - torch.tensor(0.5, dtype=shifts.dtype, device=device)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/box_head.py:363: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/box_head.py:362: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
" for idx in range(batch_size): # image idx, image inference\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:287: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:282: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" for s, s_orig in zip(new_size, original_size)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:287: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:282: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" for s, s_orig in zip(new_size, original_size)\n",
"/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_opset9.py:2378: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.\n",
" \"If indices include negative values, the exported graph will produce incorrect results.\")\n",
Expand Down Expand Up @@ -426,8 +428,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.3 s, sys: 0 ns, total: 2.3 s\n",
"Wall time: 62.3 ms\n"
"CPU times: user 2.5 s, sys: 16 ms, total: 2.51 s\n",
"Wall time: 77.6 ms\n"
]
}
],
Expand Down
53 changes: 45 additions & 8 deletions notebooks/export-relay-inference-tvm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,11 @@
" anchor_grid = torch.as_tensor(anchor_grid, dtype=dtype, device=device)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:77: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" shifts = shifts - torch.tensor(0.5, dtype=shifts.dtype, device=device)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/box_head.py:363: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/box_head.py:362: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
" for idx in range(batch_size): # image idx, image inference\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:287: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:282: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" for s, s_orig in zip(new_size, original_size)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:287: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:282: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" for s, s_orig in zip(new_size, original_size)\n"
]
}
Expand Down Expand Up @@ -203,8 +203,8 @@
"source": [
"from yolort.utils import get_image_from_url\n",
"\n",
"img = get_image_from_url(\"https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/test/assets/bus.jpg\")\n",
"# img = cv2.imread('../test/assets/bus.jpg')\n",
"# img = get_image_from_url(\"https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/test/assets/bus.jpg\")\n",
"img = cv2.imread('../test/assets/bus.jpg')\n",
"\n",
"img = img.astype(\"float32\")\n",
"img = cv2.resize(img, (in_size, in_size))\n",
Expand Down Expand Up @@ -383,8 +383,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 640 ms, sys: 392 ms, total: 1.03 s\n",
"Wall time: 25.9 ms\n"
"CPU times: user 456 ms, sys: 964 ms, total: 1.42 s\n",
"Wall time: 36.4 ms\n"
]
}
],
Expand Down Expand Up @@ -429,7 +429,44 @@
" else:\n",
" break\n",
"\n",
"print(\"Get {} valid boxes\".format(len(valid_boxes)))"
"print(f\"Get {len(valid_boxes)} valid boxes\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify the Inference Output on TVM backend"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad():\n",
" torch_res = model(torch.from_numpy(img))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exported model has been tested with TVM Runtime, and the result looks good!\n"
]
}
],
"source": [
"for i in range(len(torch_res)):\n",
" torch.testing.assert_allclose(torch_res[i], tvm_res[i].asnumpy(), rtol=1e-04, atol=1e-07)\n",
"\n",
"print(\"Exported model has been tested with TVM Runtime, and the result looks good!\")"
]
}
],
Expand Down
68 changes: 35 additions & 33 deletions notebooks/inference-pytorch-export-libtorch.ipynb

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions test/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def test_yolov5s_script(self):

out = model(x)
out_script = scripted_model(x)
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"]))
self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"]))

def test_yolov5m_script(self):
model = yolov5m(pretrained=True)
Expand All @@ -33,9 +33,9 @@ def test_yolov5m_script(self):

out = model(x)
out_script = scripted_model(x)
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"]))
self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"]))

def test_yolov5l_script(self):
model = yolov5l(pretrained=True)
Expand All @@ -48,9 +48,9 @@ def test_yolov5l_script(self):

out = model(x)
out_script = scripted_model(x)
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"]))
self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"]))

def test_yolotr_script(self):
model = yolotr(pretrained=True)
Expand All @@ -63,6 +63,6 @@ def test_yolotr_script(self):

out = model(x)
out_script = scripted_model(x)
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"]))
self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"]))
15 changes: 5 additions & 10 deletions yolort/models/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __repr__(self):
return str(self.tensors)


class GeneralizedYOLOTransform(nn.Module):
class YOLOTransform(nn.Module):
"""
Performs input / target transformation before feeding the data to a YOLO
model.
Expand Down Expand Up @@ -146,22 +146,17 @@ def resize(

def postprocess(
self,
result: Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]],
result: List[Dict[str, Tensor]],
image_shapes: List[Tuple[int, int]],
original_image_sizes: List[Tuple[int, int]],
) -> List[Dict[str, Tensor]]:

if torch.jit.is_scripting():
predictions = result[1]
else:
predictions = result

for i, (pred, im_s, o_im_s) in enumerate(zip(predictions, image_shapes, original_image_sizes)):
for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
boxes = pred["boxes"]
boxes = resize_boxes(boxes, im_s, o_im_s)
predictions[i]["boxes"] = boxes
result[i]["boxes"] = boxes

return predictions
return result


def nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisible: int = 32):
Expand Down
19 changes: 12 additions & 7 deletions yolort/models/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pytorch_lightning import LightningModule

from . import yolo
from .transform import GeneralizedYOLOTransform
from .transform import YOLOTransform
from ._utils import _evaluate_iou
from ..data import DetectionDataModule, DataPipeline, COCOEvaluator

Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(
self.model = yolo.__dict__[arch](
pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs)

self.transform = GeneralizedYOLOTransform(min_size, max_size)
self.transform = YOLOTransform(min_size, max_size)

self._data_pipeline = None

Expand All @@ -65,7 +65,7 @@ def _forward_impl(
self,
inputs: List[Tensor],
targets: Optional[List[Dict[str, Tensor]]] = None,
) -> List[Dict[str, Tensor]]:
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
"""
Args:
inputs (list[Tensor]): images to be processed
Expand Down Expand Up @@ -102,13 +102,18 @@ def _forward_impl(
losses = outputs
else:
# Rescale coordinate
detections = self.transform.postprocess(outputs, samples.image_sizes, original_image_sizes)
if torch.jit.is_scripting():
result = outputs[1]
else:
result = outputs

detections = self.transform.postprocess(result, samples.image_sizes, original_image_sizes)

if torch.jit.is_scripting():
if not self._has_warned:
warnings.warn("YOLOModule always returns Detections in scripting.")
warnings.warn("YOLOModule always returns a (Losses, Detections) tuple in scripting.")
self._has_warned = True
return detections
return losses, detections
else:
return self.eager_outputs(losses, detections)

Expand All @@ -127,7 +132,7 @@ def forward(
self,
inputs: List[Tensor],
targets: Optional[List[Dict[str, Tensor]]] = None,
) -> List[Dict[str, Tensor]]:
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
"""
This exists since PyTorch Lightning's ``forward`` is used for inference only (separate from
``training_step``). We keep ``targets`` here for backward compatibility.
Expand Down

0 comments on commit 0dfc40f

Please sign in to comment.