From 8c9aaac85d4a1027c72c66657baa5c808a3f38b9 Mon Sep 17 00:00:00 2001
From: Amit Klinger
Date: Sun, 17 Dec 2023 12:11:15 +0200
Subject: [PATCH] Working on new checkpoint hook

---
 configs/fcn/fcn_hailo_10classes.py       |  4 +-
 configs/fcn/fcn_hailo_10classes_epoch.py | 80 ++++++++++++++++++++++++
 configs/fcn/fcn_hailo_prune.py           | 20 +++---
 mmseg/engine/hooks/checkpoint_hook.py    | 19 ++++++
 mmseg/utils/misc.py                      | 60 +++++++++++++++++-
 recipes/recipe_yolox_hailo_pruning.md    | 30 +--------
 sparsity/sparseml_hook.py                | 35 +++++------
 tools/pytorch2onnx.py                    | 65 ++-----------------
 tools/test.py                            | 19 +++++-
 tools/train.py                           |  2 +-
 10 files changed, 214 insertions(+), 120 deletions(-)
 create mode 100644 configs/fcn/fcn_hailo_10classes_epoch.py
 create mode 100644 mmseg/engine/hooks/checkpoint_hook.py

diff --git a/configs/fcn/fcn_hailo_10classes.py b/configs/fcn/fcn_hailo_10classes.py
index 2c19faf865..663fc1356a 100644
--- a/configs/fcn/fcn_hailo_10classes.py
+++ b/configs/fcn/fcn_hailo_10classes.py
@@ -23,8 +23,8 @@
 # default hooks - logger & checkpoint configs
 default_hooks = dict(
 
-    # print log every 100 iterations.
-    logger=dict(type='LoggerHook', interval=200, log_metric_by_epoch=False),
+    # print log every 400 iterations.
+    logger=dict(type='LoggerHook', interval=400, log_metric_by_epoch=False),
 
     # enable the parameter scheduler.
     param_scheduler=dict(type='ParamSchedulerHook'),
diff --git a/configs/fcn/fcn_hailo_10classes_epoch.py b/configs/fcn/fcn_hailo_10classes_epoch.py
new file mode 100644
index 0000000000..b4d212e6f3
--- /dev/null
+++ b/configs/fcn/fcn_hailo_10classes_epoch.py
@@ -0,0 +1,80 @@
+# model settings
+_base_ = [
+    '../_base_/datasets/cityscapes10classes.py', '../_base_/default_runtime.py',
+]
+
+# optimizer
+optimizer = dict(type='Adam', lr=0.001, weight_decay=1e-5)
+optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
+
+# learning policy
+param_scheduler = [
+    dict(
+        type='LinearLR', start_factor=0.2, begin=0, end=1),
+    dict(
+        type='CosineAnnealingLR', begin=1, end=5, eta_min=0.00001)
+]
+
+# runtime settings
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=5, val_interval=1)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# default hooks - logger & checkpoint configs
+default_hooks = dict(
+
+    # print log every iteration.
+    logger=dict(type='LoggerHook', interval=1), #, log_metric_by_epoch=False),
+
+    # enable the parameter scheduler.
+    param_scheduler=dict(type='ParamSchedulerHook'),
+
+    # save a checkpoint every epoch; keep the last 5 plus the best mIoU.
+    checkpoint=dict(type='CheckpointHook',
+                    interval=1,
+                    save_best='mIoU',
+                    rule='greater',
+                    max_keep_ckpts=5),
+)
+
+# tensorboard vis ('LocalVisBackend' might be redundant) save_dir='./tf_dir/'
+visualizer = dict(type='SegLocalVisualizer',
+                  vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')],
+                  name='visualizer')
+
+# data preprocessing
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+crop_size = (512, 1024)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[0.0, 0.0, 0.0],
+    std=[1.0, 1.0, 1.0],
+    bgr_to_rgb=True,
+    pad_val=0,
+    seg_pad_val=255,
+    size=crop_size)
+
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        type='hailoFPN',
+        depth=0.33,
+        width=0.125,
+        bb_channels_list=[128, 256, 512, 1024],
+        bb_num_repeats_list=[9, 15, 21, 12],
+        neck_channels_list=[256, 128, 128, 256, 256, 512],
+        neck_num_repeats_list=[9, 12, 12, 9]),
+    decode_head=dict(
+        type='ConvHead',
+        in_channels=16,
+        channels=128,
+        num_convs=1,
+        num_classes=10,
+        norm_cfg=norm_cfg,
+        align_corners=True,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'),
+    infer_wo_softmax=True)
diff --git a/configs/fcn/fcn_hailo_prune.py b/configs/fcn/fcn_hailo_prune.py
index ab5f84be83..97a6bd5ac6 100644
--- a/configs/fcn/fcn_hailo_prune.py
+++ b/configs/fcn/fcn_hailo_prune.py
@@ -4,7 +4,8 @@
 ]
 
 resume = True
-load_from='./work_dirs/fcn_hailo_eta1e5/iter_68448.pth'
+# load_from='./work_dirs/fcn_hailo_eta1e5/iter_68448.pth' # best checkpoint path of full training (fcn_hailo_10classes); start of pruning procedure
+load_from='./work_dirs/fcn_hailo_eta1e5_eve/iter_74400.pth'
 
 # optimizer
 optimizer = dict(type='Adam', lr=0.0001, weight_decay=1e-5)
@@ -12,7 +13,7 @@
 
 
 # runtime settings
-train_cfg = dict(type='IterBasedTrainLoop', max_iters=173760, val_interval=1488) # 74400 (50 epochs), 89280 (60 epochs), 104160 (70 epochs), 119040 (80 epochs)
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=178560, val_interval=1488) # 74400 (50 epochs), 89280 (60 epochs), 104160 (70 epochs), 119040 (80 epochs), 178560 (120 epochs); previous max_iters was 173760
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
 
@@ -24,15 +25,20 @@
     # enable the parameter scheduler.
     param_scheduler=dict(type='ParamSchedulerHook'),
+    )
 
-    # save checkpoint every 1 epoch.
-    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=1488, save_best='mIoU', rule='greater',
-                    max_keep_ckpts=5, save_begin=163680), # 2976 (2Epoches), 7440 (5 Epoches) , max_keep_ckpts=5
-    )
+    # # save checkpoint every 1 epoch.
+    # checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=1488, save_best='mIoU', rule='greater',
+    #                 max_keep_ckpts=5, save_begin=163680), # 2976 (2 epochs), 7440 (5 epochs)
+    # )
 
 # learning policy: taken from the recipe
 
 # custom hooks
-custom_hooks = [dict(type='SparseMLHook', interval=10, priority='NORMAL')]
+sparseml_hook = dict(type='SparseMLHook', priority='NORMAL')
+# sparseml_hook = dict(type='SparseMLHook', interval=10, priority='NORMAL')
+ext_checkpoint_hook = dict(type='ExtCheckpointHook', by_epoch=False, interval=1488, save_best='mIoU', rule='greater',
+                           max_keep_ckpts=5, save_begin=163680) # 2976 (2 epochs), 7440 (5 epochs), 80352 (54 epochs), 83328 (56 epochs), 163680 (110 epochs)
+custom_hooks = [sparseml_hook, ext_checkpoint_hook]
 
 # tensorboard vis ('LocalVisBackend' might be redundant) save_dir='./tf_dir/'
 visualizer = dict(type='SegLocalVisualizer',
diff --git a/mmseg/engine/hooks/checkpoint_hook.py b/mmseg/engine/hooks/checkpoint_hook.py
new file mode 100644
index 0000000000..b7b820759e
--- /dev/null
+++ b/mmseg/engine/hooks/checkpoint_hook.py
@@ -0,0 +1,19 @@
+from mmengine.hooks import CheckpointHook
+from mmseg.registry import HOOKS
+
+
+@HOOKS.register_module()
+class ExtCheckpointHook(CheckpointHook):
+    # def __init__(self):
+    #     self.by_epoch = False
+
+    def after_val_epoch(self, runner, metrics):
+        if runner.iter == self.save_begin:
+            runner.logger.info('Resetting best_score to 0.0')
+            runner.message_hub.update_info('best_score', 0.0)
+            runner.message_hub.pop_info('best_ckpt', None)
+        if (runner.iter + 1 >= self.save_begin):
+            runner.logger.info('ExtCheckpointHook: past save_begin, tracking best checkpoint')
+            runner.logger.info(
+                f'Saving checkpoint at iter {runner.iter}')
+        super().after_val_epoch(runner, metrics)
diff --git a/mmseg/utils/misc.py b/mmseg/utils/misc.py
index 0a561732e9..a9d890c055 100644
--- a/mmseg/utils/misc.py
+++ b/mmseg/utils/misc.py
@@ -4,7 +4,7 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-
+from collections import OrderedDict
 from .typing_utils import SampleList
 
 
@@ -116,3 +116,61 @@ def stack_batch(inputs: List[torch.Tensor],
                 pad_shape=pad_img.shape[-2:]))
 
     return torch.stack(padded_inputs, dim=0), padded_samples
+
+
+def load_pretrained_weights_soft(model, checkpoint, logger):
+
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+
+    model_dict = model.state_dict()
+    new_state_dict = OrderedDict()
+    matched_layers, discarded_layers = [], []
+
+    for k, v in state_dict.items():
+        if k.startswith('module.'):
+            k = k[7:]  # discard module.
+
+        if k in model_dict and model_dict[k].size() == v.size():
+            new_state_dict[k] = v
+            matched_layers.append(k)
+        else:
+            discarded_layers.append(k)
+
+    model_dict.update(new_state_dict)
+    model.load_state_dict(model_dict)
+
+    if len(matched_layers) == 0:
+        logger.warning(
+            'The pretrained weights cannot be loaded, '
+            'please check the key names manually '
+        )
+    else:
+        logger.info('Successfully loaded pretrained weights')
+        if len(discarded_layers) > 0:
+            logger.warning(
+                '** The following layers are discarded '
+                'due to unmatched keys or layer size: {}'.
+                format(discarded_layers)
+            )
+    return
+
+
+def calc_sparsity(model_dict, logger, verbose=False):
+    weights_layers_num, total_weights, total_zeros = 0, 0, 0
+    prefix = next(iter(model_dict)).split('backbone.stage0')[0]
+    for k, v in model_dict.items():
+        if k.startswith(prefix) and k.endswith('weight'):
+            weights_layers_num += 1
+            total_weights += v.numel()
+            total_zeros += (v.numel() - v.count_nonzero())
+            zeros_ratio = (v.numel() - v.count_nonzero()) / v.numel() * 100.0
+            if verbose:
+                logger.info(f"[{weights_layers_num:>2}] {k:<51}:: {v.numel() - v.count_nonzero():<5} / {v.numel():<7}"
+                            f" ({zeros_ratio:<4.1f}%) are zeros")
+    logger.info(f"Model has {weights_layers_num} weight layers")
+    logger.info(f"Overall Sparsity is roughly: {100 * total_zeros / total_weights:.1f}%")
diff --git a/recipes/recipe_yolox_hailo_pruning.md b/recipes/recipe_yolox_hailo_pruning.md
index 2b389524a5..bb2f649738 100644
--- a/recipes/recipe_yolox_hailo_pruning.md
+++ b/recipes/recipe_yolox_hailo_pruning.md
@@ -16,7 +16,7 @@ init_sparsity: 0.01
 final_sparsity: 0.60
 pruning_start_epoch: 60
 pruning_end_epoch: 110
-pruning_update_frequency: 5.0
+pruning_update_frequency: 2.0
 
 #Modifiers
 training_modifiers:
@@ -38,31 +38,3 @@ pruning_modifiers:
     end_epoch: eval(pruning_end_epoch)
     update_frequency: eval(pruning_update_frequency)
 ---
-
-training_modifiers:
-  - !EpochRangeModifier
-    start_epoch: 0
-    end_epoch: eval(num_epochs)
-
-  - !LearningRateFunctionModifier
-    start_epoch: 3
-    end_epoch: eval(num_epochs)
-    lr_func: linear
-    init_lr: eval(init_lr)
-    final_lr: eval(final_lr)
-
-  - !LearningRateFunctionModifier
-    start_epoch: 0
-    end_epoch: 3
-    lr_func: linear
-    init_lr: eval(weights_warmup_lr)
-    final_lr: eval(init_lr)
-    param_groups: [0, 1]
-
-  - !LearningRateFunctionModifier
-    start_epoch: 0
-    end_epoch: 3
-    lr_func: linear
-    init_lr: eval(biases_warmup_lr)
-    final_lr: eval(init_lr)
-    param_groups: [2]
\ No newline at end of file
diff --git a/sparsity/sparseml_hook.py b/sparsity/sparseml_hook.py
index 59786a166c..6462c9f12d 100644
--- a/sparsity/sparseml_hook.py
+++ b/sparsity/sparseml_hook.py
@@ -1,26 +1,33 @@
-from mmseg.registry import RUNNERS, HOOKS
+from mmseg.registry import HOOKS
+from mmseg.utils.misc import calc_sparsity
 from mmengine.hooks import Hook
 from sparseml.pytorch.optim import ScheduledModifierManager
 
+
 @HOOKS.register_module()
 class SparseMLHook(Hook):
-    def __init__(self, interval=10):
-        self.interval = interval
+    def __init__(self, steps_per_epoch=1488, start_epoch=50, prune_interval_epoch=2):
+        self.steps_per_epoch = steps_per_epoch
+        self.start_epoch = start_epoch
+        self.prune_interval_epoch = prune_interval_epoch
 
     def before_train(self, runner) -> None:
        self.manager = ScheduledModifierManager.from_yaml(runner.cfg.recipe)
         optimizer = runner.optim_wrapper.optimizer
-        optimizer = self.manager.modify(runner.model.module, optimizer, steps_per_epoch=1488, epoch=40)
+        optimizer = self.manager.modify(runner.model.module,
+                                        optimizer,
+                                        steps_per_epoch=self.steps_per_epoch,
+                                        epoch=self.start_epoch)
         runner.optim_wrapper.optimizer = optimizer
 
     def after_train(self, runner) -> None:
         self.manager.finalize(runner.model.module)
 
     def after_train_iter(self, runner, batch_idx, data_batch, outputs):
-        if batch_idx % (1488 * 2) == 0: # 2 Epochs
-            runner.logger.info(f"Epoch #{batch_idx // 1488} End")
-            self._calc_sparsity(runner.model.state_dict(), runner.logger)
+        if batch_idx % (self.steps_per_epoch * self.prune_interval_epoch) == 0: # every prune_interval_epoch (2) epochs
+            calc_sparsity(runner.model.state_dict(), runner.logger)
runner.logger.info(f"Epoch #{batch_idx // self.steps_per_epoch} End") def after_test_epoch(self, runner, metrics): runner.logger.info("Switching to deployment model") @@ -28,16 +35,4 @@ def after_test_epoch(self, runner, metrics): for module in runner.model.modules(): if hasattr(module, 'switch_to_deploy'): module.switch_to_deploy() - self._calc_sparsity(runner.model.state_dict(), runner.logger) - - def _calc_sparsity(self, model_dict, logger): - weights_layers_num, total_weights, total_zeros = 0, 0, 0 - prefix = next(iter(model_dict)).split('backbone.stage0')[0] - for k, v in model_dict.items(): - if k.startswith(prefix) and k.endswith('weight'): - weights_layers_num += 1 - total_weights += v.numel() - total_zeros += (v.numel() - v.count_nonzero()) - logger.info(f"Model has {weights_layers_num} weight layers") - logger.info(f"Overall Sparsity is roughly: {100 * total_zeros / total_weights:.1f}%") - + calc_sparsity(runner.model.state_dict(), runner.logger, True) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index a313ad57ec..e223ceda72 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -8,61 +8,20 @@ from mmengine.config import Config, DictAction from mmengine.runner import Runner from mmseg.registry import RUNNERS +from mmseg.utils.misc import calc_sparsity, load_pretrained_weights_soft import onnx from onnxsim import simplify -from collections import OrderedDict - -def load_pretrained_weights_soft(model, checkpoint, logger): - - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - elif 'model' in checkpoint: - state_dict = checkpoint['model'] - else: - state_dict = checkpoint - - model_dict = model.state_dict() - new_state_dict = OrderedDict() - matched_layers, discarded_layers = [], [] - - for k, v in state_dict.items(): - if k.startswith('module.'): - k = k[7:] # discard module. - - if k in model_dict and model_dict[k].size() == v.size(): - new_state_dict[k] = v - matched_layers.append(k) - else: - discarded_layers.append(k) - - model_dict.update(new_state_dict) - model.load_state_dict(model_dict) - - if len(matched_layers) == 0: - logger.warning( - 'The pretrained weights cannot be loaded, ' - 'please check the key names manually ' - ) - else: - logger.info('Successfully loaded pretrained weights') - if len(discarded_layers) > 0: - logger.warning( - '** The following layers are discarded ' - 'due to unmatched keys or layer size: {}'. 
-                format(discarded_layers)
-            )
-
 
 def dummy_prune_ckpt(ckpt, prune_ratio=0.5, random_prune=False):
+    prefix = next(iter(ckpt['state_dict'])).split('backbone.stage0')[0]
     for k, v in ckpt['state_dict'].items():
-        if k.startswith('backbone.') and k.endswith('.rbr_dense.conv.weight'):
+        if k.startswith(prefix) and k.endswith('.rbr_dense.conv.weight'):
             if random_prune:
                 # Sparsify layer randomly:
                 v = random_prune_layer(v, prune_ratio)
             else:
                 # Sparsify layer according to magnitude:
                 v = dummy_prune_layer(v, prune_ratio)
-    calc_sparsity(ckpt['state_dict'])
     return ckpt
 
 
@@ -108,18 +67,6 @@ def dummy_prune_layer(layer, prune_ratio=0.5):
     return pruned_tensor
 
 
-def calc_sparsity(model_dict, logger):
-    weights_layers_num, total_weights, total_zeros = 0, 0, 0
-    for k, v in model_dict.items():
-        if k.startswith('backbone.') and k.endswith('weight'):
-            weights_layers_num += 1
-            total_weights += v.numel()
-            total_zeros += (v.numel() - v.count_nonzero())
-            zeros_ratio = (v.numel() - v.count_nonzero()) / v.numel() * 100.0
-            logger.info(f"[{weights_layers_num:>2}] {k:<51}:: {v.numel() - v.count_nonzero():<5} / {v.numel():<7} ({zeros_ratio:<4.1f}%) are zeros")
-    logger.info(f"Model has {weights_layers_num} weight layers")
-    logger.info(f"Overall Sparsity is roughly: {100 * total_zeros / total_weights:.1f}%")
-
 
 def parse_args():
     parser.add_argument('config', help='train config file path')
@@ -216,13 +163,13 @@ def main():
         model.load_state_dict(ckpt['state_dict'])
     else:
         model.load_state_dict(ckpt)
-
+
     runner.logger.info("Switching to deployment model")
     # if repvgg style -> deploy
     for module in model.modules():
         if hasattr(module, 'switch_to_deploy'):
             module.switch_to_deploy()
-    calc_sparsity(model.state_dict(), runner.logger)
+    calc_sparsity(model.state_dict(), runner.logger, True)
 
     # to onnx
     model.eval()
@@ -231,7 +178,7 @@
     model_with_postprocess = ModelWithPostProc(model, args)
     model_with_postprocess.eval()
 
-    imgs = torch.zeros(1,3, args.shape[0], args.shape[1], dtype=torch.float32).to(device)
+    imgs = torch.zeros(1, 3, args.shape[0], args.shape[1], dtype=torch.float32).to(device)
     outputs = model_with_postprocess(imgs)
 
     torch.onnx.export(model_with_postprocess,
diff --git a/tools/test.py b/tools/test.py
index 19fa17fd07..787c5200ae 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -2,10 +2,14 @@
 import argparse
 import os
 import os.path as osp
+import torch
+from copy import deepcopy
 
 from sparsity import sparseml_hook
 from mmengine.config import Config, DictAction
 from mmengine.runner import Runner
+from mmseg.engine.hooks import checkpoint_hook
+from mmseg.utils.misc import calc_sparsity, load_pretrained_weights_soft
 
 # TODO: support fuse_conv_bn, visualization, and format_only
 
@@ -29,6 +33,10 @@ def parse_args():
         help='directory where painted images will be saved. '
         'If specified, it will be automatically saved '
         'to the work_dir/timestamp/show_dir')
+    parser.add_argument(
+        '--deploy',
+        action='store_true',
+        help='switch model to deployment mode and calculate sparsity ratio')
     parser.add_argument(
         '--wait-time', type=float, default=2, help='the interval of show (s)')
     parser.add_argument(
@@ -115,7 +123,16 @@ def main():
 
     # build the runner from config
     runner = Runner.from_cfg(cfg)
-
+    if args.deploy:
+        ckpt = torch.load(args.checkpoint, map_location='cpu')
+        model_deploy = deepcopy(runner.model)
+        load_pretrained_weights_soft(model_deploy, ckpt, runner.logger)
+        runner.logger.info("Calculating sparsity ratio on deployment model")
+        # if repvgg style -> deploy
+        for module in model_deploy.modules():
+            if hasattr(module, 'switch_to_deploy'):
+                module.switch_to_deploy()
+        calc_sparsity(model_deploy.state_dict(), runner.logger, True)
 
     # start testing
     runner.test()
diff --git a/tools/train.py b/tools/train.py
index b5ab30daa5..88634829f8 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -5,6 +5,7 @@
 import os.path as osp
 
 from sparsity import sparseml_hook
+from mmseg.engine.hooks import checkpoint_hook
 
 from mmengine.config import Config, DictAction
 from mmengine.logging import print_log
@@ -93,7 +94,6 @@ def main():
     cfg.resume = args.resume
     cfg.recipe = args.recipe
     cfg.recipe_args = args.recipe_args
-    print(f"{cfg.resume=}, {cfg.load_from}")
     # build the runner from config
     if 'runner_type' not in cfg:
         # build the default runner