From aa536ec2fd890a8da88ed790f31dcae20bef0b13 Mon Sep 17 00:00:00 2001 From: qiufeng <44188071+wutongshenqiu@users.noreply.github.com> Date: Tue, 19 Apr 2022 11:31:56 +0800 Subject: [PATCH] [Feature] Add mmrazor support (#220) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Torchscript support (#159) * support torchscript * add nms * add torchscript configs and update deploy process and dump-info * typescript -> torchscript * add torchscript custom extension support * add ts custom ops again * support mmseg unet * [WIP] add optimizer for torchscript (#119) * add passes * add python api * Torchscript optimizer python api (#121) * add passes * add python api * use python api instead of executable * Merge Master, update optimizer (#151) * [Feature] add yolox ncnn (#29) * add yolox ncnn * add ncnn android performance of yolox * add ut * fix lint * fix None bugs for ncnn * test codecov * test codecov * add device * fix yapf * remove if-else for img shape * use channelshuffle optimize * change benchmark after channelshuffle * fix yapf * fix yapf * fuse continuous reshape * fix static shape deploy * fix code * drop pad * only static shape * fix static * fix docstring * Added mask overlay to output image, changed fprintf info messages to … (#55) * Added mask overlay to output image, changed fprintf info messages to stdout * Improved box filtering (filter area/score), make sure roi coordinates stay within bounds * clang-format * Support UNet in mmseg (#77) * Repeatdataset in train has no CLASSES & PALETTE * update result for unet * update docstring for mmdet * remove ppl for unet in docs * fix ort wrap about input type (#81) * Fix memleak (#86) * delete [] * fix build error when enble MMDEPLOY_ACTIVE_LEVEL * fix lint * [Doc] Nano benchmark and tutorial (#71) * add cls benchmark * add nano zh-cn benchmark and en tutorial * add device row * add doc path to index.rst * fix typo * [Fix] fix missing deploy_core (#80) * fix missing deploy_core * mv flag to demo * target link * [Docs] Fix links in Chinese doc (#84) * Fix docs in Chinese link * Fix links * Delete symbolic link and add links to html * delete files * Fix link * [Feature] Add docker files (#67) * add gpu and cpu dockerfile * fix lint * fix cpu docker and remove redundant * use pip instead * add build arg and readme * fix grammar * update readme * add chinese doc for dockerfile and add docker build to build.md * grammar * refine dockerfiles * add FAQs * update Dpplcv_DIR for SDK building * remove mmcls * add sdk demos * fix typo and lint * update FAQs * [Fix]fix check_env (#101) * fix check_env * update * Replace convert_syncbatchnorm in mmseg (#93) * replace convert_syncbatchnorm with revert_sync_batchnorm from mmcv * change logger * [Doc] Update FAQ for TensorRT (#96) * update FAQ * comment * [Docs]: Update doc for openvino installation (#102) * fix docs * fix docs * fix docs * fix mmcv version * fix docs * rm blank line * simplify non batch nms (#99) * [Enhacement] Allow test.py to save evaluation results (#108) * Add log file * Delete debug code * Rename logger * resolve comments * [Enhancement] Support mmocr v0.4+ (#115) * support mmocr v0.4+ * 0.4.0 -> 0.4.1 * fix onnxruntime wrapper for gpu inference (#123) * fix ncnn wrapper for ort-gpu * resolve comment * fix lint * Fix typo (#132) * lock mmcls version (#131) * [Enhancement] upgrade isort in pre-commit config (#141) * [Enhancement] upgrade isort in pre-commit config by refering to mmflow pr #87 * fix lint * remove .isort.cfg and put its known_third_party to setup.cfg * Fix ci for mmocr (#144) * fix mmocr unittests * remove useless * lock mmdet maximum version to 2.20 * pip install -U numpy * Fix capture_output (#125) Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com> Co-authored-by: Johannes L Co-authored-by: RunningLeon Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Co-authored-by: lvhan028 Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com> Co-authored-by: Yifan Zhou Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com> Co-authored-by: Semyon Bevzyuk * configs for all tasks * use torchvision roi align * remote unnecessary code * fix ut * fix ut * export * det dynamic * det dynamic * add ut * fix ut * add ut and docs * fix ut * skip torchscript ut if no ops available * add torchscript option to build.md * update benchmark and resolve comments * resolve conflicts * rename configs * fix mrcnn cuda test * remove useless * add version requirements to docs and comments to codes * enable empty image exporting for torchscript and accelerate ORT inference for MRCNN * rebase * update example for torchscript.md * update FAQs for torchscript.md * resolve comments * only use torchvision roi_align for torchscript * fix ut * use torchvision roi align when pool model is avg * resolve comments Co-authored-by: grimoire Co-authored-by: grimoire Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com> Co-authored-by: Johannes L Co-authored-by: RunningLeon Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Co-authored-by: lvhan028 Co-authored-by: Yifan Zhou Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com> Co-authored-by: Semyon Bevzyuk * remove roi_align plugin for ORT (#258) * remove roi_align plugin * remove ut * skip single_roi_extractor UT for ORT in CI * move align to symbolic and update docs * recover UT * resolve comments * add mmcls example * add mmcls/mmdet/mmseg and their corresponding tests * add test data * simplify test data * add requirement in optional.txt * fix setup problem when adding mmrazor requirement * use get_codebase_config * change mmrazor requirement Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com> Co-authored-by: grimoire Co-authored-by: grimoire Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com> Co-authored-by: Johannes L Co-authored-by: RunningLeon Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Co-authored-by: lvhan028 Co-authored-by: Yifan Zhou Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com> Co-authored-by: Semyon Bevzyuk --- mmdeploy/codebase/base/task.py | 21 ++++++- .../codebase/mmcls/deploy/classification.py | 6 +- .../codebase/mmdet/deploy/object_detection.py | 6 +- .../codebase/mmseg/deploy/segmentation.py | 6 +- requirements/optional.txt | 1 + .../test_mmcls/data/mmrazor_model.py | 31 ++++++++++ .../test_mmcls/data/mmrazor_mutable_cfg.yaml | 60 +++++++++++++++++++ .../test_mmcls/test_classification.py | 30 +++++++++- .../test_mmdet/data/mmrazor_model.py | 34 +++++++++++ .../test_mmdet/data/mmrazor_mutable_cfg.yaml | 60 +++++++++++++++++++ .../test_mmdet/test_object_detection.py | 30 +++++++++- .../test_mmseg/data/mmrazor_model.py | 28 +++++++++ .../test_mmseg/test_segmentation.py | 29 ++++++++- 13 files changed, 332 insertions(+), 10 deletions(-) create mode 100644 tests/test_codebase/test_mmcls/data/mmrazor_model.py create mode 100644 tests/test_codebase/test_mmcls/data/mmrazor_mutable_cfg.yaml create mode 100644 tests/test_codebase/test_mmdet/data/mmrazor_model.py create mode 100644 tests/test_codebase/test_mmdet/data/mmrazor_mutable_cfg.yaml create mode 100644 tests/test_codebase/test_mmseg/data/mmrazor_model.py diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index 6d0cdaaf8..330a61c0a 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -7,7 +7,8 @@ import torch from torch.utils.data import DataLoader, Dataset -from mmdeploy.utils import get_backend_config, get_codebase, get_root_logger +from mmdeploy.utils import (get_backend_config, get_codebase, + get_codebase_config, get_root_logger) from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset @@ -284,3 +285,21 @@ def get_model_name(self) -> str: str: the name of the model. """ pass + + @property + def from_mmrazor(self) -> bool: + """Whether the codebase from mmrazor. + + Returns: + bool: From mmrazor or not. + + Raises: + TypeError: An error when type of `from_mmrazor` is not boolean. + """ + codebase_config = get_codebase_config(self.deploy_cfg) + from_mmrazor = codebase_config.get('from_mmrazor', False) + if not isinstance(from_mmrazor, bool): + raise TypeError('`from_mmrazor` attribute must be boolean type! ' + f'but got: {from_mmrazor}') + + return from_mmrazor diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index 92d8a410f..8e91c2f4c 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -95,7 +95,11 @@ def init_pytorch_model(self, nn.Module: An initialized torch model generated by OpenMMLab codebases. """ - from mmcls.apis import init_model + if self.from_mmrazor: + from mmrazor.apis import init_mmcls_model as init_model + else: + from mmcls.apis import init_model + model = init_model(self.model_cfg, model_checkpoint, self.device, cfg_options) diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection.py b/mmdeploy/codebase/mmdet/deploy/object_detection.py index 7ba3e4955..baf388a75 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection.py @@ -89,7 +89,11 @@ def init_pytorch_model(self, nn.Module: An initialized torch model generated by other OpenMMLab codebases. """ - from mmdet.apis import init_detector + if self.from_mmrazor: + from mmrazor.apis import init_mmdet_model as init_detector + else: + from mmdet.apis import init_detector + model = init_detector(self.model_cfg, model_checkpoint, self.device, cfg_options) return model.eval() diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation.py b/mmdeploy/codebase/mmseg/deploy/segmentation.py index a3f1728ae..26d210538 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation.py @@ -89,7 +89,11 @@ def init_pytorch_model(self, codebases. """ from mmcv.cnn.utils import revert_sync_batchnorm - from mmseg.apis import init_segmentor + if self.from_mmrazor: + from mmrazor.apis import init_mmseg_model as init_segmentor + else: + from mmseg.apis import init_segmentor + model = init_segmentor(self.model_cfg, model_checkpoint, self.device) model = revert_sync_batchnorm(model) return model.eval() diff --git a/requirements/optional.txt b/requirements/optional.txt index 8e52e01a6..68d2edba8 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -3,6 +3,7 @@ mmdet>=2.19.0,<=2.20.0 mmedit mmocr>=0.3.0,<=0.4.1 mmpose>=0.24.0 +mmrazor>=0.3.0 mmsegmentation onnxruntime>=1.8.0 openvino-dev diff --git a/tests/test_codebase/test_mmcls/data/mmrazor_model.py b/tests/test_codebase/test_mmcls/data/mmrazor_model.py new file mode 100644 index 000000000..7c477685b --- /dev/null +++ b/tests/test_codebase/test_mmcls/data/mmrazor_model.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = 'model.py' + +norm_cfg = dict(type='BN') + +mutator = dict( + type='OneShotMutator', + placeholder_mapping=dict( + all_blocks=dict( + type='OneShotOP', + choices=dict( + shuffle_3x3=dict( + type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg), + shuffle_5x5=dict( + type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg), + shuffle_7x7=dict( + type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg), + shuffle_xception=dict( + type='ShuffleXception', norm_cfg=norm_cfg), + )))) + +algorithm = dict( + type='SPOS', + architecture=dict( + type='MMClsArchitecture', + model={{_base_.model}}, + ), + mutator=mutator, + distiller=None, + mutable_cfg='tests/test_codebase/test_mmcls/data/mmrazor_mutable_cfg.yaml', + retraining=True) diff --git a/tests/test_codebase/test_mmcls/data/mmrazor_mutable_cfg.yaml b/tests/test_codebase/test_mmcls/data/mmrazor_mutable_cfg.yaml new file mode 100644 index 000000000..024347f13 --- /dev/null +++ b/tests/test_codebase/test_mmcls/data/mmrazor_mutable_cfg.yaml @@ -0,0 +1,60 @@ +stage_0_block_0: + chosen: + - shuffle_7x7 +stage_0_block_1: + chosen: + - shuffle_5x5 +stage_0_block_2: + chosen: + - shuffle_3x3 +stage_0_block_3: + chosen: + - shuffle_5x5 +stage_1_block_0: + chosen: + - shuffle_7x7 +stage_1_block_1: + chosen: + - shuffle_3x3 +stage_1_block_2: + chosen: + - shuffle_7x7 +stage_1_block_3: + chosen: + - shuffle_3x3 +stage_2_block_0: + chosen: + - shuffle_7x7 +stage_2_block_1: + chosen: + - shuffle_3x3 +stage_2_block_2: + chosen: + - shuffle_7x7 +stage_2_block_3: + chosen: + - shuffle_xception +stage_2_block_4: + chosen: + - shuffle_3x3 +stage_2_block_5: + chosen: + - shuffle_3x3 +stage_2_block_6: + chosen: + - shuffle_3x3 +stage_2_block_7: + chosen: + - shuffle_3x3 +stage_3_block_0: + chosen: + - shuffle_xception +stage_3_block_1: + chosen: + - shuffle_7x7 +stage_3_block_2: + chosen: + - shuffle_xception +stage_3_block_3: + chosen: + - shuffle_xception diff --git a/tests/test_codebase/test_mmcls/test_classification.py b/tests/test_codebase/test_mmcls/test_classification.py index 6bc0856d8..6837537af 100644 --- a/tests/test_codebase/test_mmcls/test_classification.py +++ b/tests/test_codebase/test_mmcls/test_classification.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import os from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import Any import mmcv import numpy as np @@ -37,9 +39,33 @@ img = np.random.rand(*img_shape, 3) -def test_init_pytorch_model(): +@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0]) +def test_init_pytorch_model(from_mmrazor: Any): from mmcls.models.classifiers.base import BaseClassifier - model = task_processor.init_pytorch_model(None) + if from_mmrazor is False: + _task_processor = task_processor + else: + _model_cfg_path = 'tests/test_codebase/test_mmcls/data/' \ + 'mmrazor_model.py' + _model_cfg = load_config(_model_cfg_path)[0] + _model_cfg.algorithm.architecture.model.type = 'mmcls.ImageClassifier' + _model_cfg.algorithm.architecture.model.backbone = dict( + type='SearchableShuffleNetV2', widen_factor=1.0) + _deploy_cfg = copy.deepcopy(deploy_cfg) + _deploy_cfg.codebase_config['from_mmrazor'] = from_mmrazor + _task_processor = build_task_processor(_model_cfg, _deploy_cfg, 'cpu') + + if not isinstance(from_mmrazor, bool): + with pytest.raises( + TypeError, + match='`from_mmrazor` attribute must be ' + 'boolean type! ' + f'but got: {from_mmrazor}'): + _ = _task_processor.from_mmrazor + return + assert from_mmrazor == _task_processor.from_mmrazor + + model = _task_processor.init_pytorch_model(None) assert isinstance(model, BaseClassifier) diff --git a/tests/test_codebase/test_mmdet/data/mmrazor_model.py b/tests/test_codebase/test_mmdet/data/mmrazor_model.py new file mode 100644 index 000000000..e00e1c67d --- /dev/null +++ b/tests/test_codebase/test_mmdet/data/mmrazor_model.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = 'model.py' + +norm_cfg = dict(type='BN', requires_grad=True) +mutator = dict( + type='OneShotMutator', + placeholder_mapping=dict( + all_blocks=dict( + type='OneShotOP', + choices=dict( + shuffle_3x3=dict( + type='ShuffleBlock', norm_cfg=norm_cfg, kernel_size=3), + shuffle_5x5=dict( + type='ShuffleBlock', norm_cfg=norm_cfg, kernel_size=5), + shuffle_7x7=dict( + type='ShuffleBlock', norm_cfg=norm_cfg, kernel_size=7), + shuffle_xception=dict( + type='ShuffleXception', + norm_cfg=norm_cfg, + ), + )))) + +algorithm = dict( + type='DetNAS', + architecture=dict( + type='MMDetArchitecture', + model={{_base_.model}}, + ), + mutator=mutator, + pruner=None, + distiller=None, + retraining=True, + mutable_cfg='tests/test_codebase/test_mmdet/data/mmrazor_mutable_cfg.yaml', +) diff --git a/tests/test_codebase/test_mmdet/data/mmrazor_mutable_cfg.yaml b/tests/test_codebase/test_mmdet/data/mmrazor_mutable_cfg.yaml new file mode 100644 index 000000000..5321759fb --- /dev/null +++ b/tests/test_codebase/test_mmdet/data/mmrazor_mutable_cfg.yaml @@ -0,0 +1,60 @@ +stage_0_block_0: + chosen: + - shuffle_7x7 +stage_0_block_1: + chosen: + - shuffle_5x5 +stage_0_block_2: + chosen: + - shuffle_7x7 +stage_0_block_3: + chosen: + - shuffle_3x3 +stage_1_block_0: + chosen: + - shuffle_7x7 +stage_1_block_1: + chosen: + - shuffle_5x5 +stage_1_block_2: + chosen: + - shuffle_5x5 +stage_1_block_3: + chosen: + - shuffle_7x7 +stage_2_block_0: + chosen: + - shuffle_xception +stage_2_block_1: + chosen: + - shuffle_xception +stage_2_block_2: + chosen: + - shuffle_5x5 +stage_2_block_3: + chosen: + - shuffle_xception +stage_2_block_4: + chosen: + - shuffle_3x3 +stage_2_block_5: + chosen: + - shuffle_3x3 +stage_2_block_6: + chosen: + - shuffle_xception +stage_2_block_7: + chosen: + - shuffle_5x5 +stage_3_block_0: + chosen: + - shuffle_xception +stage_3_block_1: + chosen: + - shuffle_5x5 +stage_3_block_2: + chosen: + - shuffle_xception +stage_3_block_3: + chosen: + - shuffle_7x7 diff --git a/tests/test_codebase/test_mmdet/test_object_detection.py b/tests/test_codebase/test_mmdet/test_object_detection.py index 2fd40a2ae..a962f21d6 100644 --- a/tests/test_codebase/test_mmdet/test_object_detection.py +++ b/tests/test_codebase/test_mmdet/test_object_detection.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import os from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import Any import mmcv import numpy as np @@ -48,9 +50,33 @@ img = np.random.rand(*img_shape, 3) -def test_init_pytorch_model(): +@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0]) +def test_init_pytorch_model(from_mmrazor: Any): from mmdet.models import BaseDetector - model = task_processor.init_pytorch_model(None) + if from_mmrazor is False: + _task_processor = task_processor + else: + _model_cfg_path = 'tests/test_codebase/test_mmdet/data/' \ + 'mmrazor_model.py' + _model_cfg = load_config(_model_cfg_path)[0] + _model_cfg.algorithm.architecture.model.type = 'mmdet.YOLOV3' + _model_cfg.algorithm.architecture.model.backbone.type = \ + 'mmcls.SearchableShuffleNetV2' + _deploy_cfg = copy.deepcopy(deploy_cfg) + _deploy_cfg.codebase_config['from_mmrazor'] = from_mmrazor + _task_processor = build_task_processor(_model_cfg, _deploy_cfg, 'cpu') + + if not isinstance(from_mmrazor, bool): + with pytest.raises( + TypeError, + match='`from_mmrazor` attribute must be ' + 'boolean type! ' + f'but got: {from_mmrazor}'): + _ = _task_processor.from_mmrazor + return + assert from_mmrazor == _task_processor.from_mmrazor + + model = _task_processor.init_pytorch_model(None) assert isinstance(model, BaseDetector) diff --git a/tests/test_codebase/test_mmseg/data/mmrazor_model.py b/tests/test_codebase/test_mmseg/data/mmrazor_model.py new file mode 100644 index 000000000..d86ae1e54 --- /dev/null +++ b/tests/test_codebase/test_mmseg/data/mmrazor_model.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = 'model.py' + +# algorithm setting +algorithm = dict( + type='GeneralDistill', + architecture=dict( + type='MMSegArchitecture', + model={{_base_.model}}, + ), + distiller=dict( + type='SingleTeacherDistiller', + teacher={{_base_.model}}, + teacher_trainable=False, + components=[ + dict( + student_module='decode_head.conv_seg', + teacher_module='decode_head.conv_seg', + losses=[ + dict( + type='ChannelWiseDivergence', + name='loss_cwd_logits', + tau=1, + loss_weight=5, + ) + ]) + ]), +) diff --git a/tests/test_codebase/test_mmseg/test_segmentation.py b/tests/test_codebase/test_mmseg/test_segmentation.py index 008d02fb1..3884c1f23 100644 --- a/tests/test_codebase/test_mmseg/test_segmentation.py +++ b/tests/test_codebase/test_mmseg/test_segmentation.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import os from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import Any import mmcv import numpy as np @@ -37,9 +39,32 @@ img = np.random.rand(*img_shape, 3) -def test_init_pytorch_model(): +@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0]) +def test_init_pytorch_model(from_mmrazor: Any): from mmseg.models.segmentors.base import BaseSegmentor - model = task_processor.init_pytorch_model(None) + if from_mmrazor is False: + _task_processor = task_processor + else: + _model_cfg_path = 'tests/test_codebase/test_mmseg/data/' \ + 'mmrazor_model.py' + _model_cfg = load_config(_model_cfg_path)[0] + _model_cfg.algorithm.architecture.model.type = 'mmseg.EncoderDecoder' + _model_cfg.algorithm.distiller.teacher.type = 'mmseg.EncoderDecoder' + _deploy_cfg = copy.deepcopy(deploy_cfg) + _deploy_cfg.codebase_config['from_mmrazor'] = from_mmrazor + _task_processor = build_task_processor(_model_cfg, _deploy_cfg, 'cpu') + + if not isinstance(from_mmrazor, bool): + with pytest.raises( + TypeError, + match='`from_mmrazor` attribute must be ' + 'boolean type! ' + f'but got: {from_mmrazor}'): + _ = _task_processor.from_mmrazor + return + assert from_mmrazor == _task_processor.from_mmrazor + + model = _task_processor.init_pytorch_model(None) assert isinstance(model, BaseSegmentor)