diff --git a/configs/_base_/datasets/cityscapes.py b/configs/_base_/datasets/cityscapes.py
index 2f88208b51..30912cb7bc 100644
--- a/configs/_base_/datasets/cityscapes.py
+++ b/configs/_base_/datasets/cityscapes.py
@@ -1,6 +1,6 @@
 # dataset settings
 dataset_type = 'CityscapesDataset'
-data_root = '/data/cityscapes/'
+data_root = '/data/cityscapes10classes/'
 crop_size = (512, 1024)
 train_pipeline = [
     dict(type='LoadImageFromFile'),
diff --git a/configs/fcn/fcn_hailo_10classes.py b/configs/fcn/fcn_hailo_10classes.py
index 85946b4e32..9105395ae4 100644
--- a/configs/fcn/fcn_hailo_10classes.py
+++ b/configs/fcn/fcn_hailo_10classes.py
@@ -12,30 +12,32 @@
     dict(
         type='LinearLR', start_factor=0.2, by_epoch=False, begin=0, end=7440),
     dict(
-        type='CosineAnnealingLR', begin=7440, by_epoch=False, end=59520)
+        type='CosineAnnealingLR', begin=7440, end=74400, eta_min=0.00001, by_epoch=False)
 ]
 
 # runtime settings
-train_cfg = dict(type='IterBasedTrainLoop', max_iters=59520, val_interval=1488)
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=74400, val_interval=1488)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
 
 # default hooks - logger & checkpoint configs
 default_hooks = dict(
-    # print log every 100 iterations.
-    logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),
+    # print log every 400 iterations.
+    logger=dict(type='LoggerHook', interval=400, log_metric_by_epoch=False),
 
     # enable the parameter scheduler.
     param_scheduler=dict(type='ParamSchedulerHook'),
 
     # save checkpoint every 5 epochs.
-    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=7440),
+    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=7440, save_best='mIoU', rule='greater',
+                    max_keep_ckpts=5),
 )
 
-# tensorboard vis
-vis_backends = [dict(type='LocalVisBackend'),
-                dict(type='TensorboardVisBackend')]
+# tensorboard vis ('LocalVisBackend' might be redundant) save_dir='./tf_dir/'
+visualizer = dict(type='SegLocalVisualizer',
+                  vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')],
+                  name='visualizer')
 
 # data preprocessing
 norm_cfg = dict(type='SyncBN', requires_grad=True)
@@ -72,4 +74,4 @@
     # model training and testing settings
     train_cfg=dict(),
     test_cfg=dict(mode='whole'),
-    infer_wo_softmax=True)
\ No newline at end of file
+    infer_wo_softmax=True)
diff --git a/configs/fcn/fcn_hailo_10_classes_pp.py b/configs/fcn/fcn_hailo_prune.py
similarity index 54%
rename from configs/fcn/fcn_hailo_10_classes_pp.py
rename to configs/fcn/fcn_hailo_prune.py
index 63594d5d83..523b223fbc 100644
--- a/configs/fcn/fcn_hailo_10_classes_pp.py
+++ b/configs/fcn/fcn_hailo_prune.py
@@ -3,39 +3,42 @@
     '../_base_/datasets/cityscapes10classes.py',
     '../_base_/default_runtime.py',
 ]
+resume = True
+# best checkpoint path of full training (fcn_hailo_10classes). Start of pruning procedure:
+load_from = './work_dirs/fcn_hailo_eta1e5_eve/iter_74400.pth'
+
 # optimizer
-optimizer = dict(type='Adam', lr=0.001, weight_decay=1e-5)
+optimizer = dict(type='Adam', lr=0.0001, weight_decay=1e-5)
 optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
 
-# learning policy
-param_scheduler = [
-    dict(
-        type='LinearLR', start_factor=0.2, by_epoch=False, begin=0, end=7440),
-    dict(
-        type='CosineAnnealingLR', begin=7440, by_epoch=False, end=59520)
-]
 
 # runtime settings
-train_cfg = dict(type='IterBasedTrainLoop', max_iters=59520, val_interval=1488)
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=178560, val_interval=1488)  # 74400 (50 epochs), 178560 (120)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
 
 # default hooks - logger & checkpoint configs
 default_hooks = dict(
-    # print log every 100 iterations.
-    logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),
+    # print log every 500 iterations.
+    logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=False),
 
     # enable the parameter scheduler.
     param_scheduler=dict(type='ParamSchedulerHook'),
+    )
 
-    # save checkpoint every 5 epochs.
-    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=7440),
-)
+# learning policy: taken from the recipe
+# custom hooks
+sparseml_hook = dict(type='SparseMLHook', priority='NORMAL')
+# Saving best checkpoint starts after pruning hits final ratio
+ext_checkpoint_hook = dict(type='ExtCheckpointHook', by_epoch=False, interval=1488, save_best='mIoU', rule='greater',
+                           max_keep_ckpts=5, save_begin=163680)  # 163680 (110 epochs)
+custom_hooks = [sparseml_hook, ext_checkpoint_hook]
 
-# tensorboard vis
-vis_backends = [dict(type='LocalVisBackend'),
-                dict(type='TensorboardVisBackend')]
+# tensorboard vis ('LocalVisBackend' might be redundant) save_dir='./tf_dir/'
+visualizer = dict(type='SegLocalVisualizer',
+                  vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')],
+                  name='visualizer')
 
 # data preprocessing
 norm_cfg = dict(type='SyncBN', requires_grad=True)
@@ -60,7 +63,7 @@
         neck_channels_list=[256, 128, 128, 256, 256, 512],
         neck_num_repeats_list=[9, 12, 12, 9]),
     decode_head=dict(
-        type='PostProcess',
+        type='ConvHead',
        in_channels=16,
         channels=128,
         num_convs=1,
@@ -72,4 +75,4 @@
     # model training and testing settings
     train_cfg=dict(),
     test_cfg=dict(mode='whole'),
-    infer_wo_softmax=True)
\ No newline at end of file
+    infer_wo_softmax=True)
diff --git a/mmseg/engine/hooks/checkpoint_hook.py b/mmseg/engine/hooks/checkpoint_hook.py
new file mode 100644
index 0000000000..d752fd839e
--- /dev/null
+++ b/mmseg/engine/hooks/checkpoint_hook.py
@@ -0,0 +1,16 @@
+from mmengine.hooks import CheckpointHook
+from mmseg.registry import HOOKS
+
+
+@HOOKS.register_module()
+class ExtCheckpointHook(CheckpointHook):
+
+    def after_val_epoch(self, runner, metrics):
+        if runner.iter == self.save_begin:
+            runner.logger.info('Resetting best_score to 0.0')
+            runner.message_hub.update_info('best_score', 0.0)
+            runner.message_hub.pop_info('best_ckpt', None)
+        if (runner.iter + 1 >= self.save_begin):
+            runner.logger.info(
+                f'Saving checkpoint at iter {runner.iter}')
+            super().after_val_epoch(runner, metrics)
diff --git a/mmseg/utils/misc.py b/mmseg/utils/misc.py
index 0a561732e9..a9d890c055 100644
--- a/mmseg/utils/misc.py
+++ b/mmseg/utils/misc.py
@@ -4,7 +4,7 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-
+from collections import OrderedDict
 from .typing_utils import SampleList
 
 
@@ -116,3 +116,61 @@ def stack_batch(inputs: List[torch.Tensor],
                 pad_shape=pad_img.shape[-2:]))
 
     return torch.stack(padded_inputs, dim=0), padded_samples
+
+
+def load_pretrained_weights_soft(model, checkpoint, logger):
+
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+
+    model_dict = model.state_dict()
+    new_state_dict = OrderedDict()
+    matched_layers, discarded_layers = [], []
+
+    for k, v in state_dict.items():
+        if k.startswith('module.'):
+            k = k[7:]  # discard module.
+
+        if k in model_dict and model_dict[k].size() == v.size():
+            new_state_dict[k] = v
+            matched_layers.append(k)
+        else:
+            discarded_layers.append(k)
+
+    model_dict.update(new_state_dict)
+    model.load_state_dict(model_dict)
+
+    if len(matched_layers) == 0:
+        logger.warning(
+            'The pretrained weights cannot be loaded, '
+            'please check the key names manually '
+        )
+    else:
+        logger.info('Successfully loaded pretrained weights')
+        if len(discarded_layers) > 0:
+            logger.warning(
+                '** The following layers are discarded '
+                'due to unmatched keys or layer size: {}'.
+                format(discarded_layers)
+            )
+    return
+
+
+def calc_sparsity(model_dict, logger, verbose=False):
+    weights_layers_num, total_weights, total_zeros = 0, 0, 0
+    prefix = next(iter(model_dict)).split('backbone.stage0')[0]
+    for k, v in model_dict.items():
+        if k.startswith(prefix) and k.endswith('weight'):
+            weights_layers_num += 1
+            total_weights += v.numel()
+            total_zeros += (v.numel() - v.count_nonzero())
+            zeros_ratio = (v.numel() - v.count_nonzero()) / v.numel() * 100.0
+            if verbose:
+                logger.info(f"[{weights_layers_num:>2}] {k:<51}:: {v.numel() - v.count_nonzero():<5} / {v.numel():<7}"
+                            f" ({zeros_ratio:<4.1f}%) are zeros")
+    logger.info(f"Model has {weights_layers_num} weight layers")
+    logger.info(f"Overall Sparsity is roughly: {100 * total_zeros / total_weights:.1f}%")
diff --git a/recipes/recipe_yolox_hailo_pruning.md b/recipes/recipe_yolox_hailo_pruning.md
new file mode 100644
index 0000000000..bb2f649738
--- /dev/null
+++ b/recipes/recipe_yolox_hailo_pruning.md
@@ -0,0 +1,40 @@
+
+---
+
+version: 1.1.0
+
+# General Hyperparams
+start_epoch: 50
+num_epochs: 120
+init_lr: 0.00001
+final_lr: 0.00001
+weights_warmup_lr: 0
+biases_warmup_lr: 0
+
+# Pruning Hyperparams
+init_sparsity: 0.01
+final_sparsity: 0.60
+pruning_start_epoch: 60
+pruning_end_epoch: 110
+pruning_update_frequency: 2.0
+
+#Modifiers
+training_modifiers:
+  - !LearningRateFunctionModifier
+    start_epoch: eval(start_epoch)
+    end_epoch: eval(num_epochs)
+    lr_func: linear
+    init_lr: eval(init_lr)
+    final_lr: eval(init_lr)
+
+pruning_modifiers:
+  - !GMPruningModifier
+    params:
+      - re:backbone.backbone.*.*.rbr_dense.conv.weight
+      - re:backbone.neck.*.*.rbr_dense.conv.weight
+    init_sparsity: eval(init_sparsity)
+    final_sparsity: eval(final_sparsity)
+    start_epoch: eval(pruning_start_epoch)
+    end_epoch: eval(pruning_end_epoch)
+    update_frequency: eval(pruning_update_frequency)
+---
diff --git a/sparsity/sparseml_hook.py b/sparsity/sparseml_hook.py
new file mode 100644
index 0000000000..6462c9f12d
--- /dev/null
+++ b/sparsity/sparseml_hook.py
@@ -0,0 +1,38 @@
+from mmseg.registry import HOOKS
+from mmseg.utils.misc import calc_sparsity
+from mmengine.hooks import Hook
+from sparseml.pytorch.optim import ScheduledModifierManager
+
+
+@HOOKS.register_module()
+class SparseMLHook(Hook):
+    def __init__(self, steps_per_epoch=1488, start_epoch=50, prune_interval_epoch=2):
+        self.steps_per_epoch = steps_per_epoch
+        self.start_epoch = start_epoch
+        self.prune_interval_epoch = prune_interval_epoch
+
+    def before_train(self, runner) -> None:
+        self.manager = ScheduledModifierManager.from_yaml(runner.cfg.recipe)
+
+        optimizer = runner.optim_wrapper.optimizer
+        optimizer = self.manager.modify(runner.model.module,
+                                        optimizer,
+                                        steps_per_epoch=self.steps_per_epoch,
+                                        epoch=self.start_epoch)
+        runner.optim_wrapper.optimizer = optimizer
+
+    def after_train(self, runner) -> None:
+        self.manager.finalize(runner.model.module)
+
+    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
+        if batch_idx % (self.steps_per_epoch * self.prune_interval_epoch) == 0:  # 2 Epochs
+            calc_sparsity(runner.model.state_dict(), runner.logger)
+            runner.logger.info(f"Epoch #{batch_idx // self.steps_per_epoch} End")
+
+    def after_test_epoch(self, runner, metrics):
+        runner.logger.info("Switching to deployment model")
+        # if repvgg style -> deploy
+        for module in runner.model.modules():
+            if hasattr(module, 'switch_to_deploy'):
+                module.switch_to_deploy()
+        calc_sparsity(runner.model.state_dict(), runner.logger, True)
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index 78851f0d2a..e223ceda72 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -1,67 +1,71 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
-import logging
 import os
 import os.path as osp
 
 import torch
-import torch.nn.functional as F
 import torch.nn as nn
 from mmengine.config import Config, DictAction
-from mmengine.logging import print_log
 from mmengine.runner import Runner
-
 from mmseg.registry import RUNNERS
+from mmseg.utils.misc import calc_sparsity, load_pretrained_weights_soft
 
 import onnx
 from onnxsim import simplify
-from mmseg.models.utils import resize
-
-import torch.nn.functional as F
-
-from collections import OrderedDict
-import warnings
-
-def load_pretrained_weights_soft(model, checkpoint):
-
-    if 'state_dict' in checkpoint:
-        state_dict = checkpoint['state_dict']
-    elif 'model' in checkpoint:
-        state_dict = checkpoint['model']
-    else:
-        state_dict = checkpoint
-
-    model_dict = model.state_dict()
-    new_state_dict = OrderedDict()
-    matched_layers, discarded_layers = [], []
-    for k, v in state_dict.items():
-        if k.startswith('module.'):
-            k = k[7:]  # discard module.
-        if k in model_dict and model_dict[k].size() == v.size():
-            new_state_dict[k] = v
-            matched_layers.append(k)
-        else:
-            discarded_layers.append(k)
-
-    model_dict.update(new_state_dict)
-    model.load_state_dict(model_dict)
-
-    if len(matched_layers) == 0:
-        warnings.warn(
-            'The pretrained weights "{}" cannot be loaded, '
-            'please check the key names manually '
-            '(** ignored and continue **)'
-        )
-    else:
-        print('Successfully loaded pretrained weights')
-        if len(discarded_layers) > 0:
-            print(
-                '** The following layers are discarded '
-                'due to unmatched keys or layer size: {}'.
-                format(discarded_layers)
-            )
+def dummy_prune_ckpt(ckpt, prune_ratio=0.5, random_prune=False):
+    prefix = next(iter(ckpt['state_dict'])).split('backbone.stage0')[0]
+    for k, v in ckpt['state_dict'].items():
+        if k.startswith(prefix) and k.endswith('.rbr_dense.conv.weight'):
+            if random_prune:  # Sparsify layer randomly:
+                v = random_prune_layer(v, prune_ratio)
+            else:  # Sparsify layer according to magnitude:
+                v = dummy_prune_layer(v, prune_ratio)
+    return ckpt
+
+
+def random_prune_layer(layer, prune_ratio=0.5):
+    """
+    Randomly prune (set to zero) a fraction of elements in a PyTorch tensor.
+
+    Args:
+        layer (torch.Tensor): Input tensor of shape [B, C, H, W].
+        prune_ratio (float): Fraction of elements to set to zero.
+
+    Returns:
+        torch.Tensor: Pruned tensor with the same shape as the input.
+    """
+    # Determine the number of elements to prune
+    num_elements = layer.numel()
+    num_prune = int(prune_ratio * num_elements)
+
+    # Create a mask with zeros and ones to select the elements to prune
+    mask = torch.ones(num_elements)
+    mask[:num_prune] = 0
+    mask = mask[torch.randperm(num_elements)]  # Shuffle the mask randomly
+    mask = mask.view(layer.shape)
+
+    # Apply the mask to the input tensor to prune it
+    layer *= mask
+    return layer
+
+
+def dummy_prune_layer(layer, prune_ratio=0.5):
+    # Flatten the tensor
+    flattened_layer = layer.flatten()
+    # Get the absolute values
+    abs_values = torch.abs(flattened_layer)
+    # Get indices sorted by absolute values
+    sorted_indices = torch.argsort(abs_values)
+    # Determine the threshold index
+    threshold_index = int(prune_ratio * len(sorted_indices))
+    # Set values below the threshold to zero
+    flattened_layer[sorted_indices[:threshold_index]] = 0
+    # Reshape the tensor back to its original shape
+    pruned_tensor = flattened_layer.reshape(layer.shape)
+
+    return pruned_tensor
 
 
 def parse_args():
@@ -69,10 +73,13 @@
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument('--checkpoint', help='checkpoint file', default=None)
     parser.add_argument('--no_simplify', action='store_false')
-    parser.add_argument('--no_postprocess', action='store_true', default=False)
+    parser.add_argument('--postprocess', action='store_true', default=False)
     parser.add_argument('--shape', nargs=2, type=int, default=[1024, 1920])
+    parser.add_argument('-o', '--opset', type=int, default=13)
     parser.add_argument('--out_name', default='fcn.onnx', type=str, help="Name for the onnx output")
-    parser.add_argument('--soft_weights_loading',action='store_true', default=False)
+    parser.add_argument('--soft_weights_loading', action='store_true', default=False)
+    parser.add_argument('--dummy_prune_ratio', type=float, default=0.0, help="Applies dummy pruning with ratio")
+    parser.add_argument('--random_prune', action='store_true', default=False, help="Set method to prune as random (default: Minimum absolute value)")
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -103,7 +110,7 @@
 class ModelWithPostProc(torch.nn.Module):
     def __init__(self, model, args):
         super(ModelWithPostProc, self).__init__()
         self.model = model
-        self.post_proc_flag = not(args.no_postprocess)
+        self.post_proc_flag = args.postprocess
         self.shape = args.shape
         self.bilinear_resize = nn.Upsample(size=self.shape, mode='bilinear', align_corners=True)
@@ -144,40 +151,53 @@
         # if 'runner_type' is set in the cfg
         runner = RUNNERS.build(cfg)
 
-    # start training
     model = runner.model
 
     if args.checkpoint:
         ckpt = torch.load(args.checkpoint, map_location='cpu')
         if args.soft_weights_loading:
-            load_pretrained_weights_soft(model, ckpt)
+            if args.dummy_prune_ratio > 0.0:
+                ckpt = dummy_prune_ckpt(ckpt, args.dummy_prune_ratio, args.random_prune)
+            load_pretrained_weights_soft(model, ckpt, runner.logger)
         else:
             if 'state_dict' in ckpt:
                 model.load_state_dict(ckpt['state_dict'])
             else:
                 model.load_state_dict(ckpt)
-
+
+    runner.logger.info("Switching to deployment model")
     # if repvgg style -> deploy
     for module in model.modules():
         if hasattr(module, 'switch_to_deploy'):
             module.switch_to_deploy()
+    calc_sparsity(model.state_dict(), runner.logger, True)
 
     # to onnx
     model.eval()
+    if args.postprocess:
+        runner.logger.info("Adding Postprocess (Resize+ArgMax) to the model")
     model_with_postprocess = ModelWithPostProc(model, args)
     model_with_postprocess.eval()
-    imgs = torch.zeros(1,3, args.shape[0], args.shape[1], dtype=torch.float32).to(device)
+
+    imgs = torch.zeros(1, 3, args.shape[0], args.shape[1], dtype=torch.float32).to(device)
     outputs = model_with_postprocess(imgs)
-    torch.onnx.export(model_with_postprocess, imgs, args.out_name, input_names=['test_input'], output_names=['output'], training=torch.onnx.TrainingMode.PRESERVE, opset_version=13)
-    print('model saved at: ', args.out_name)
+    torch.onnx.export(model_with_postprocess,
+                      imgs, args.out_name,
+                      input_names=['test_input'],
+                      output_names=['output'],
+                      training=torch.onnx.TrainingMode.PRESERVE,
+                      opset_version=args.opset)
 
     # if also simplify
     if args.no_simplify:
         model_onnx = onnx.load(args.out_name)
         model_simp, check = simplify(model_onnx)
-        onnx.save(model_simp, args.out_name[0:-5] + '_simplify.onnx')
-        print('model simplified saved at: ', args.out_name[0:-5] + '_simplify.onnx')
+        onnx.save(model_simp, args.out_name)
+        runner.logger.info(f"Simplified model saved at: {args.out_name}")
+    else:
+        runner.logger.info(f"Model saved at: {args.out_name}")
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(epilog='Example: CUDA_VISIBLE_DEVICES=0 python tools/pytorch2onnx.py configs/fcn/fcn8_r18_hailo.py --checkpoint work_dirs/fcn8_r18_hailo_iterbased/epoch_1.pth --out_name my_fcn_model.onnx --shape 600 800')
+    parser = argparse.ArgumentParser(
+        epilog='Example: CUDA_VISIBLE_DEVICES=0 python tools/pytorch2onnx.py configs/fcn/fcn_hailo_10classes.py --checkpoint work_dirs/fcn_hailo/iter_173760.pth --shape 736 960 --postprocess --soft_weights_loading --out_name fcn_hailo.onnx')
     main()
diff --git a/tools/test.py b/tools/test.py
index 058fdfc864..787c5200ae 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -2,9 +2,14 @@
 import argparse
 import os
 import os.path as osp
+import torch
+from copy import deepcopy
+from sparsity import sparseml_hook
 
 from mmengine.config import Config, DictAction
 from mmengine.runner import Runner
+from mmseg.engine.hooks import checkpoint_hook
+from mmseg.utils.misc import calc_sparsity, load_pretrained_weights_soft
 
 
 # TODO: support fuse_conv_bn, visualization, and format_only
@@ -28,6 +33,10 @@ def parse_args():
         help='directory where painted images will be saved. '
         'If specified, it will be automatically saved '
         'to the work_dir/timestamp/show_dir')
+    parser.add_argument(
+        '--deploy',
+        action='store_true',
+        help='switch model to deployment mode and calculate sparsity ratio')
     parser.add_argument(
         '--wait-time', type=float, default=2, help='the interval of show (s)')
     parser.add_argument(
@@ -114,7 +123,16 @@
 
     # build the runner from config
     runner = Runner.from_cfg(cfg)
-
+    if args.deploy:
+        ckpt = torch.load(args.checkpoint, map_location='cpu')
+        model_deploy = deepcopy(runner.model)
+        load_pretrained_weights_soft(model_deploy, ckpt, runner.logger)
+        runner.logger.info("Calculating sparsity ratio on deployment model")
+        # if repvgg style -> deploy
+        for module in model_deploy.modules():
+            if hasattr(module, 'switch_to_deploy'):
+                module.switch_to_deploy()
+        calc_sparsity(model_deploy.state_dict(), runner.logger, True)
     # start testing
     runner.test()
diff --git a/tools/train.py b/tools/train.py
index 10fdaa1874..88634829f8 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -4,12 +4,13 @@
 import os
 import os.path as osp
+from sparsity import sparseml_hook
+from mmseg.engine.hooks import checkpoint_hook
+
 
 from mmengine.config import Config, DictAction
 from mmengine.logging import print_log
 from mmengine.runner import Runner
 
-from mmseg.registry import RUNNERS
-
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a segmentor')
@@ -20,6 +21,11 @@
         action='store_true',
         default=False,
         help='resume from the latest checkpoint in the work_dir automatically')
+    parser.add_argument('--recipe', type=str, default=None, help='Path to a sparsification recipe, '
+                        'see https://github.com/neuralmagic/sparseml for more information')
+    parser.add_argument("--recipe-args", type=str, default=None, help = 'A json string, csv key=value string, or dictionary '
+                        'containing arguments to override the root arguments '
+                        'within the recipe such as learning rate or num epochs')
     parser.add_argument(
         '--amp',
         action='store_true',
@@ -86,7 +92,8 @@
 
     # resume training
     cfg.resume = args.resume
-
+    cfg.recipe = args.recipe
+    cfg.recipe_args = args.recipe_args
     # build the runner from config
     if 'runner_type' not in cfg:
         # build the default runner
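For reference, a minimal sketch of how the pieces in this patch could be chained end to end; the checkpoint names, work-dir paths and GPU index below are illustrative placeholders, not values defined by the patch itself:

    # 1. Prune-aware fine-tuning driven by the SparseML recipe (GMP pruning to 60% sparsity between epochs 60 and 110)
    CUDA_VISIBLE_DEVICES=0 python tools/train.py configs/fcn/fcn_hailo_prune.py --recipe recipes/recipe_yolox_hailo_pruning.md

    # 2. Evaluate a pruned checkpoint; --deploy re-parameterizes RepVGG-style blocks and logs the sparsity ratio
    python tools/test.py configs/fcn/fcn_hailo_prune.py work_dirs/fcn_hailo_prune/<best_mIoU_ckpt>.pth --deploy

    # 3. Export to ONNX with the resize+argmax post-process attached
    python tools/pytorch2onnx.py configs/fcn/fcn_hailo_prune.py --checkpoint work_dirs/fcn_hailo_prune/<best_mIoU_ckpt>.pth --shape 736 960 --postprocess --soft_weights_loading --out_name fcn_hailo_pruned.onnx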