Merge pull request #8 from amitklinger/Hailo-2.0
Hailo 2.0: Pruning Support + Best pruned checkpoint save
HailoModelZooValidation authored Dec 19, 2023
2 parents 6b68c9b + e1627ad commit 5d616e3
Showing 10 changed files with 297 additions and 95 deletions.
2 changes: 1 addition & 1 deletion configs/_base_/datasets/cityscapes.py
@@ -1,6 +1,6 @@
# dataset settings
dataset_type = 'CityscapesDataset'
data_root = '/data/cityscapes/'
data_root = '/data/cityscapes10classes/'
crop_size = (512, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
20 changes: 11 additions & 9 deletions configs/fcn/fcn_hailo_10classes.py
@@ -12,30 +12,32 @@
dict(
type='LinearLR', start_factor=0.2, by_epoch=False, begin=0, end=7440),
dict(
type='CosineAnnealingLR', begin=7440, by_epoch=False, end=59520)
type='CosineAnnealingLR', begin=7440, end=74400, eta_min=0.00001, by_epoch=False)
]

# runtime settings
train_cfg = dict(type='IterBasedTrainLoop', max_iters=59520, val_interval=1488)
train_cfg = dict(type='IterBasedTrainLoop', max_iters=74400, val_interval=1488)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# default hooks - logger & checkpoint configs
default_hooks = dict(

# print log every 100 iterations.
logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),
# print log every 400 iterations.
logger=dict(type='LoggerHook', interval=400, log_metric_by_epoch=False),

# enable the parameter scheduler.
param_scheduler=dict(type='ParamSchedulerHook'),

# save checkpoint every 5 epochs.
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=7440),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=7440, save_best='mIoU', rule='greater',
max_keep_ckpts=5),
)

# tensorboard vis
vis_backends = [dict(type='LocalVisBackend'),
dict(type='TensorboardVisBackend')]
# tensorboard vis ('LocalVisBackend' might be redundant) save_dir='./tf_dir/<exp_name>'
visualizer = dict(type='SegLocalVisualizer',
vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')],
name='visualizer')

# data preprocessing
norm_cfg = dict(type='SyncBN', requires_grad=True)
@@ -72,4 +74,4 @@
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'),
infer_wo_softmax=True)
infer_wo_softmax=True)
@@ -3,39 +3,42 @@
'../_base_/datasets/cityscapes10classes.py', '../_base_/default_runtime.py',
]

resume = True
# best checkpoint path of full training (fcn_hailo_10classes). Start of pruning procedure:
load_from = './work_dirs/fcn_hailo_eta1e5_eve/iter_74400.pth'

# optimizer
optimizer = dict(type='Adam', lr=0.001, weight_decay=1e-5)
optimizer = dict(type='Adam', lr=0.0001, weight_decay=1e-5)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)

# learning policy
param_scheduler = [
dict(
type='LinearLR', start_factor=0.2, by_epoch=False, begin=0, end=7440),
dict(
type='CosineAnnealingLR', begin=7440, by_epoch=False, end=59520)
]

# runtime settings
train_cfg = dict(type='IterBasedTrainLoop', max_iters=59520, val_interval=1488)
train_cfg = dict(type='IterBasedTrainLoop', max_iters=178560, val_interval=1488) # 74400 (50 epochs), 178560 (120)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# default hooks - logger & checkpoint configs
default_hooks = dict(

# print log every 100 iterations.
logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),
# print log every 500 iterations.
logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=False),

# enable the parameter scheduler.
param_scheduler=dict(type='ParamSchedulerHook'),
)

# save checkpoint every 5 epochs.
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=7440),
)
# learning policy: taken from the recipe
# custom hooks
sparseml_hook = dict(type='SparseMLHook', priority='NORMAL')
# Saving best checkpoint starts after pruning hits final ratio
ext_checkpoint_hook = dict(type='ExtCheckpointHook', by_epoch=False, interval=1488, save_best='mIoU', rule='greater',
max_keep_ckpts=5, save_begin=163680) # 163680 (110 epochs)
custom_hooks = [sparseml_hook, ext_checkpoint_hook]

# tensorboard vis
vis_backends = [dict(type='LocalVisBackend'),
dict(type='TensorboardVisBackend')]
# tensorboard vis ('LocalVisBackend' might be redundant) save_dir='./tf_dir/<exp_name>'
visualizer = dict(type='SegLocalVisualizer',
vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')],
name='visualizer')

# data preprocessing
norm_cfg = dict(type='SyncBN', requires_grad=True)
@@ -60,7 +63,7 @@
neck_channels_list=[256, 128, 128, 256, 256, 512],
neck_num_repeats_list=[9, 12, 12, 9]),
decode_head=dict(
type='PostProcess',
type='ConvHead',
in_channels=16,
channels=128,
num_convs=1,
@@ -72,4 +75,4 @@
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'),
infer_wo_softmax=True)
infer_wo_softmax=True)
16 changes: 16 additions & 0 deletions mmseg/engine/hooks/checkpoint_hook.py
@@ -0,0 +1,16 @@
from mmengine.hooks import CheckpointHook
from mmseg.registry import HOOKS


@HOOKS.register_module()
class ExtCheckpointHook(CheckpointHook):

    def after_val_epoch(self, runner, metrics):
        # Once training reaches save_begin (i.e. pruning has hit its final
        # ratio), reset the tracked best score so that the best-mIoU checkpoint
        # is picked only from the iterations that follow.
        if runner.iter == self.save_begin:
            runner.logger.info('Resetting best_score to 0.0')
            runner.message_hub.update_info('best_score', 0.0)
            runner.message_hub.pop_info('best_ckpt', None)
        # Delegate saving (including best-checkpoint bookkeeping) to the base
        # CheckpointHook only from save_begin onwards.
        if runner.iter + 1 >= self.save_begin:
            runner.logger.info(f'Saving checkpoint at iter {runner.iter}')
            super().after_val_epoch(runner, metrics)
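
A minimal sketch of how this hook can be wired into a training config (mirroring the pruning config above; the interval and save_begin values are per-experiment choices, not defaults of the hook):

# Track the best-mIoU checkpoint, but only once pruning has reached its final
# ratio: save_begin=163680 iterations, i.e. 110 epochs at 1488 iters/epoch.
ext_checkpoint_hook = dict(type='ExtCheckpointHook', by_epoch=False, interval=1488,
                           save_best='mIoU', rule='greater', max_keep_ckpts=5,
                           save_begin=163680)
custom_hooks = [ext_checkpoint_hook]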
60 changes: 59 additions & 1 deletion mmseg/utils/misc.py
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import torch
import torch.nn.functional as F

from collections import OrderedDict
from .typing_utils import SampleList


@@ -116,3 +116,61 @@ def stack_batch(inputs: List[torch.Tensor],
pad_shape=pad_img.shape[-2:]))

return torch.stack(padded_inputs, dim=0), padded_samples


def load_pretrained_weights_soft(model, checkpoint, logger):
    # Accept checkpoints stored as {'state_dict': ...}, {'model': ...} or as a
    # raw state dict.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]  # discard the 'module.' prefix added by (Distributed)DataParallel

        # Keep only weights whose name and shape match the current model.
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        logger.warning(
            'The pretrained weights cannot be loaded, '
            'please check the key names manually')
    else:
        logger.info('Successfully loaded pretrained weights')
        if len(discarded_layers) > 0:
            logger.warning(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.format(discarded_layers))
    return


def calc_sparsity(model_dict, logger, verbose=False):
    weights_layers_num, total_weights, total_zeros = 0, 0, 0
    # Strip any wrapper prefix (e.g. 'module.') by taking what precedes
    # 'backbone.stage0' in the first key of the state dict.
    prefix = next(iter(model_dict)).split('backbone.stage0')[0]
    for k, v in model_dict.items():
        if k.startswith(prefix) and k.endswith('weight'):
            weights_layers_num += 1
            total_weights += v.numel()
            total_zeros += (v.numel() - v.count_nonzero())
            zeros_ratio = (v.numel() - v.count_nonzero()) / v.numel() * 100.0
            if verbose:
                logger.info(f"[{weights_layers_num:>2}] {k:<51}:: {v.numel() - v.count_nonzero():<5} / {v.numel():<7}"
                            f" ({zeros_ratio:<4.1f}%) are zeros")
    logger.info(f"Model has {weights_layers_num} weight layers")
    logger.info(f"Overall Sparsity is roughly: {100 * total_zeros / total_weights:.1f}%")
40 changes: 40 additions & 0 deletions recipes/recipe_yolox_hailo_pruning.md
@@ -0,0 +1,40 @@

---

version: 1.1.0

# General Hyperparams
start_epoch: 50
num_epochs: 120
init_lr: 0.00001
final_lr: 0.00001
weights_warmup_lr: 0
biases_warmup_lr: 0

# Pruning Hyperparams
init_sparsity: 0.01
final_sparsity: 0.60
pruning_start_epoch: 60
pruning_end_epoch: 110
pruning_update_frequency: 2.0

# Modifiers
training_modifiers:
  - !LearningRateFunctionModifier
    start_epoch: eval(start_epoch)
    end_epoch: eval(num_epochs)
    lr_func: linear
    init_lr: eval(init_lr)
    final_lr: eval(init_lr)

pruning_modifiers:
  - !GMPruningModifier
    params:
      - re:backbone.backbone.*.*.rbr_dense.conv.weight
      - re:backbone.neck.*.*.rbr_dense.conv.weight
    init_sparsity: eval(init_sparsity)
    final_sparsity: eval(final_sparsity)
    start_epoch: eval(pruning_start_epoch)
    end_epoch: eval(pruning_end_epoch)
    update_frequency: eval(pruning_update_frequency)
---
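
The recipe is consumed through SparseML's ScheduledModifierManager, exactly as SparseMLHook does below; a standalone sketch in which model and optimizer are placeholders:

from sparseml.pytorch.optim import ScheduledModifierManager

# Parse the recipe and wrap the optimizer so gradual magnitude pruning follows
# the schedule above (steps_per_epoch and epoch match the hook's defaults).
manager = ScheduledModifierManager.from_yaml('recipes/recipe_yolox_hailo_pruning.md')
optimizer = manager.modify(model, optimizer, steps_per_epoch=1488, epoch=50)
# ... run training with the wrapped optimizer ...
manager.finalize(model)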
38 changes: 38 additions & 0 deletions sparsity/sparseml_hook.py
@@ -0,0 +1,38 @@
from mmseg.registry import HOOKS
from mmseg.utils.misc import calc_sparsity
from mmengine.hooks import Hook
from sparseml.pytorch.optim import ScheduledModifierManager


@HOOKS.register_module()
class SparseMLHook(Hook):
    def __init__(self, steps_per_epoch=1488, start_epoch=50, prune_interval_epoch=2):
        self.steps_per_epoch = steps_per_epoch
        self.start_epoch = start_epoch
        self.prune_interval_epoch = prune_interval_epoch

    def before_train(self, runner) -> None:
        # Parse the SparseML recipe (runner.cfg.recipe, e.g.
        # recipes/recipe_yolox_hailo_pruning.md) and wrap the optimizer so the
        # pruning schedule is applied during training.
        self.manager = ScheduledModifierManager.from_yaml(runner.cfg.recipe)

        optimizer = runner.optim_wrapper.optimizer
        optimizer = self.manager.modify(runner.model.module,
                                        optimizer,
                                        steps_per_epoch=self.steps_per_epoch,
                                        epoch=self.start_epoch)
        runner.optim_wrapper.optimizer = optimizer

    def after_train(self, runner) -> None:
        self.manager.finalize(runner.model.module)

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        # Log the model's sparsity every `prune_interval_epoch` epochs (2 epochs).
        if batch_idx % (self.steps_per_epoch * self.prune_interval_epoch) == 0:
            calc_sparsity(runner.model.state_dict(), runner.logger)
            runner.logger.info(f"Epoch #{batch_idx // self.steps_per_epoch} End")

    def after_test_epoch(self, runner, metrics):
        runner.logger.info("Switching to deployment model")
        # If modules expose switch_to_deploy (RepVGG-style blocks), fuse them
        # for deployment before reporting the final sparsity.
        for module in runner.model.modules():
            if hasattr(module, 'switch_to_deploy'):
                module.switch_to_deploy()
        calc_sparsity(runner.model.state_dict(), runner.logger, True)
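
To enable the hook, the training config is assumed to expose a top-level recipe key (the hook reads it via runner.cfg.recipe) and to list the hook in custom_hooks, as the pruning config above does; a hedged sketch:

recipe = 'recipes/recipe_yolox_hailo_pruning.md'  # assumed config key read by SparseMLHook
sparseml_hook = dict(type='SparseMLHook', priority='NORMAL')
custom_hooks = [sparseml_hook]  # the pruning config above also appends ExtCheckpointHook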