Commit 8c9aaac: Working on new checkpoint hook

amitklinger committed Dec 17, 2023
1 parent 32c9ba8 commit 8c9aaac
Showing 10 changed files with 214 additions and 120 deletions.
4 changes: 2 additions & 2 deletions configs/fcn/fcn_hailo_10classes.py
@@ -23,8 +23,8 @@
# default hooks - logger & checkpoint configs
default_hooks = dict(

-# print log every 100 iterations.
-logger=dict(type='LoggerHook', interval=200, log_metric_by_epoch=False),
+# print log every 400 iterations.
+logger=dict(type='LoggerHook', interval=400, log_metric_by_epoch=False),

# enable the parameter scheduler.
param_scheduler=dict(type='ParamSchedulerHook'),
80 changes: 80 additions & 0 deletions configs/fcn/fcn_hailo_10classes_epoch.py
@@ -0,0 +1,80 @@
# model settings
_base_ = [
'../_base_/datasets/cityscapes10classes.py', '../_base_/default_runtime.py',
]

# optimizer
optimizer = dict(type='Adam', lr=0.001, weight_decay=1e-5)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)

# learning policy
param_scheduler = [
dict(
type='LinearLR', start_factor=0.2, begin=0, end=1),
dict(
type='CosineAnnealingLR', begin=1, end=5, eta_min=0.00001)
]

# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=5, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# default hooks - logger & checkpoint configs
default_hooks = dict(

# print log every iteration.
logger=dict(type='LoggerHook', interval=1), #, log_metric_by_epoch=False),

# enable the parameter scheduler.
param_scheduler=dict(type='ParamSchedulerHook'),

# save a checkpoint every epoch; keep at most 5, tracking the best mIoU.
checkpoint=dict(type='CheckpointHook',
interval=1,
save_best='mIoU',
rule='greater',
max_keep_ckpts=5),
)

# tensorboard vis ('LocalVisBackend' might be redundant) save_dir='./tf_dir/<exp_name>'
visualizer = dict(type='SegLocalVisualizer',
vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')],
name='visualizer')

# data preprocessing
norm_cfg = dict(type='SyncBN', requires_grad=True)
crop_size = (512, 1024)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255,
size=crop_size)

model = dict(
type='EncoderDecoder',
backbone=dict(
type='hailoFPN',
depth=0.33,
width=0.125,
bb_channels_list=[128, 256, 512, 1024],
bb_num_repeats_list=[9, 15, 21, 12],
neck_channels_list=[256, 128, 128, 256, 256, 512],
neck_num_repeats_list=[9, 12, 12, 9]),
decode_head=dict(
type='ConvHead',
in_channels=16,
channels=128,
num_convs=1,
num_classes=10,
norm_cfg=norm_cfg,
align_corners=True,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'),
infer_wo_softmax=True)
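
For context, a minimal launch sketch for an epoch-based config like the one above, assuming a standard MMSegmentation/MMEngine setup; the work_dir value is illustrative, not from this commit:

    from mmengine.config import Config
    from mmengine.runner import Runner

    cfg = Config.fromfile('configs/fcn/fcn_hailo_10classes_epoch.py')
    cfg.work_dir = './work_dirs/fcn_hailo_10classes_epoch'  # illustrative output dir
    runner = Runner.from_cfg(cfg)
    runner.train()  # EpochBasedTrainLoop: 5 epochs, validating every epoch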
20 changes: 13 additions & 7 deletions configs/fcn/fcn_hailo_prune.py
@@ -4,15 +4,16 @@
]

resume = True
-load_from='./work_dirs/fcn_hailo_eta1e5/iter_68448.pth'
+# load_from='./work_dirs/fcn_hailo_eta1e5/iter_68448.pth' # best checkpoint of the full training run (fcn_hailo_10classes); starting point of the pruning procedure
+load_from='./work_dirs/fcn_hailo_eta1e5_eve/iter_74400.pth'

# optimizer
optimizer = dict(type='Adam', lr=0.0001, weight_decay=1e-5)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)


# runtime settings
-train_cfg = dict(type='IterBasedTrainLoop', max_iters=173760, val_interval=1488) # 74400 (50 epochs), 89280 (60 epochs), 104160 (70 epochs), 119040 (80 epochs)
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=178560, val_interval=1488) # 74400 (50 epochs), 89280 (60 epochs), 104160 (70 epochs), 119040 (80 epochs), 173760, 178560 (120 epochs)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

@@ -24,15 +25,20 @@

# enable the parameter scheduler.
param_scheduler=dict(type='ParamSchedulerHook'),
+)

-# save checkpoint every 1 epoch.
-checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=1488, save_best='mIoU', rule='greater',
-max_keep_ckpts=5, save_begin=163680), # 2976 (2 epochs), 7440 (5 epochs), max_keep_ckpts=5
-)
+# # save checkpoint every 1 epoch.
+# checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=1488, save_best='mIoU', rule='greater',
+# max_keep_ckpts=5, save_begin=163680), # 2976 (2 epochs), 7440 (5 epochs)
+# )

# learning policy: taken from the recipe
# custom hooks
-custom_hooks = [dict(type='SparseMLHook', interval=10, priority='NORMAL')]
+sparseml_hook = dict(type='SparseMLHook', priority='NORMAL')
+# sparseml_hook = dict(type='SparseMLHook', interval=10, priority='NORMAL')
+ext_checkpoint_hook = dict(type='ExtCheckpointHook', by_epoch=False, interval=1488, save_best='mIoU', rule='greater',
+max_keep_ckpts=5, save_begin=163680) # 2976 (2 epochs), 7440 (5 epochs), 80352 (54 epochs), 83328 (56 epochs), 163680 (110 epochs)
+custom_hooks = [sparseml_hook, ext_checkpoint_hook]

# tensorboard vis ('LocalVisBackend' might be redundant) save_dir='./tf_dir/<exp_name>'
visualizer = dict(type='SegLocalVisualizer',
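A small sanity check on the numbers above (an inference, not stated in the commit): with 1488 iterations per epoch, save_begin=163680 lands exactly at epoch 110, which matches pruning_end_epoch in the pruning recipe below, so best-mIoU tracking starts only once the target sparsity has been reached:

    steps_per_epoch = 1488
    save_begin_epoch = 163680 / steps_per_epoch  # 110.0, matching pruning_end_epoch in the recipe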
19 changes: 19 additions & 0 deletions mmseg/engine/hooks/checkpoint_hook.py
@@ -0,0 +1,19 @@
from mmengine.hooks import CheckpointHook
from mmseg.registry import HOOKS


@HOOKS.register_module()
class ExtCheckpointHook(CheckpointHook):
    """CheckpointHook variant that restarts best-checkpoint tracking at ``save_begin``."""
    # def __init__(self):
    #     self.by_epoch = False

    def after_val_epoch(self, runner, metrics):
        if runner.iter == self.save_begin:
            # Forget the best score/checkpoint recorded before save_begin so the
            # best model is re-selected from this point onward.
            runner.logger.info('Resetting best_score to 0.0')
            runner.message_hub.update_info('best_score', 0.0)
            runner.message_hub.pop_info('best_ckpt', None)
        if runner.iter + 1 >= self.save_begin:
            runner.logger.info(f'ExtCheckpointHook: saving checkpoint at iter {runner.iter}')
            super().after_val_epoch(runner, metrics)
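
As a usage sketch, the hook is registered through custom_hooks rather than default_hooks (mirroring fcn_hailo_prune.py above; the values shown are the ones used there):

    custom_hooks = [
        dict(type='ExtCheckpointHook', by_epoch=False, interval=1488,
             save_best='mIoU', rule='greater', max_keep_ckpts=5, save_begin=163680),
    ]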
60 changes: 59 additions & 1 deletion mmseg/utils/misc.py
@@ -4,7 +4,7 @@
import numpy as np
import torch
import torch.nn.functional as F
-
+from collections import OrderedDict
from .typing_utils import SampleList


@@ -116,3 +116,61 @@ def stack_batch(inputs: List[torch.Tensor],
pad_shape=pad_img.shape[-2:]))

return torch.stack(padded_inputs, dim=0), padded_samples


def load_pretrained_weights_soft(model, checkpoint, logger):
    """Soft-load a checkpoint: copy only the keys that exist in ``model`` with
    matching tensor shapes, and report everything that was skipped."""
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]  # strip the 'module.' prefix added by DataParallel

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        logger.warning(
            'The pretrained weights cannot be loaded, '
            'please check the key names manually'
        )
    else:
        logger.info('Successfully loaded pretrained weights')
        if len(discarded_layers) > 0:
            logger.warning(
                '** The following layers are discarded '
                f'due to unmatched keys or layer size: {discarded_layers}'
            )
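
A quick usage sketch, assuming an already-built model; the checkpoint path mirrors the prune config above and the logger setup is illustrative:

    import torch
    from mmengine.logging import MMLogger

    logger = MMLogger.get_instance('mmseg')
    checkpoint = torch.load('./work_dirs/fcn_hailo_eta1e5_eve/iter_74400.pth', map_location='cpu')
    load_pretrained_weights_soft(model, checkpoint, logger)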


def calc_sparsity(model_dict, logger, verbose=False):
    weights_layers_num, total_weights, total_zeros = 0, 0, 0
    prefix = next(iter(model_dict)).split('backbone.stage0')[0]
    for k, v in model_dict.items():
        if k.startswith(prefix) and k.endswith('weight'):
            weights_layers_num += 1
            total_weights += v.numel()
            total_zeros += (v.numel() - v.count_nonzero())
            zeros_ratio = (v.numel() - v.count_nonzero()) / v.numel() * 100.0
            if verbose:
                logger.info(f"[{weights_layers_num:>2}] {k:<51}:: {v.numel() - v.count_nonzero():<5} / {v.numel():<7}"
                            f" ({zeros_ratio:<4.1f}%) are zeros")
    logger.info(f"Model has {weights_layers_num} weight layers")
    logger.info(f"Overall Sparsity is roughly: {100 * total_zeros / total_weights:.1f}%")
30 changes: 1 addition & 29 deletions recipes/recipe_yolox_hailo_pruning.md
@@ -16,7 +16,7 @@ init_sparsity: 0.01
final_sparsity: 0.60
pruning_start_epoch: 60
pruning_end_epoch: 110
-pruning_update_frequency: 5.0
+pruning_update_frequency: 2.0

#Modifiers
training_modifiers:
@@ -38,31 +38,3 @@ pruning_modifiers:
end_epoch: eval(pruning_end_epoch)
update_frequency: eval(pruning_update_frequency)
---

-training_modifiers:
-  - !EpochRangeModifier
-    start_epoch: 0
-    end_epoch: eval(num_epochs)
-
-  - !LearningRateFunctionModifier
-    start_epoch: 3
-    end_epoch: eval(num_epochs)
-    lr_func: linear
-    init_lr: eval(init_lr)
-    final_lr: eval(final_lr)
-
-  - !LearningRateFunctionModifier
-    start_epoch: 0
-    end_epoch: 3
-    lr_func: linear
-    init_lr: eval(weights_warmup_lr)
-    final_lr: eval(init_lr)
-    param_groups: [0, 1]
-
-  - !LearningRateFunctionModifier
-    start_epoch: 0
-    end_epoch: 3
-    lr_func: linear
-    init_lr: eval(biases_warmup_lr)
-    final_lr: eval(init_lr)
-    param_groups: [2]
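
One way to read the frequency change above (an inference about SparseML's gradual-magnitude pruning schedule, not stated in the recipe): update_frequency is the spacing, in epochs, between sparsity-mask updates, so lowering it from 5.0 to 2.0 makes the ramp from init_sparsity to final_sparsity finer-grained:

    updates_before = (110 - 60) / 5.0  # 10 mask updates between epochs 60 and 110
    updates_after = (110 - 60) / 2.0   # 25 mask updates over the same range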
35 changes: 15 additions & 20 deletions sparsity/sparseml_hook.py
@@ -1,43 +1,38 @@
-from mmseg.registry import RUNNERS, HOOKS
+from mmseg.registry import HOOKS
+from mmseg.utils.misc import calc_sparsity
from mmengine.hooks import Hook
from sparseml.pytorch.optim import ScheduledModifierManager


@HOOKS.register_module()
class SparseMLHook(Hook):
-    def __init__(self, interval=10):
-        self.interval = interval
+    def __init__(self, steps_per_epoch=1488, start_epoch=50, prune_interval_epoch=2):
+        self.steps_per_epoch = steps_per_epoch
+        self.start_epoch = start_epoch
+        self.prune_interval_epoch = prune_interval_epoch

    def before_train(self, runner) -> None:
        self.manager = ScheduledModifierManager.from_yaml(runner.cfg.recipe)

        optimizer = runner.optim_wrapper.optimizer
-        optimizer = self.manager.modify(runner.model.module, optimizer, steps_per_epoch=1488, epoch=40)
+        optimizer = self.manager.modify(runner.model.module,
+                                        optimizer,
+                                        steps_per_epoch=self.steps_per_epoch,
+                                        epoch=self.start_epoch)
        runner.optim_wrapper.optimizer = optimizer

    def after_train(self, runner) -> None:
        self.manager.finalize(runner.model.module)

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
-        if batch_idx % (1488 * 2) == 0:  # 2 Epochs
-            runner.logger.info(f"Epoch #{batch_idx // 1488} End")
-            self._calc_sparsity(runner.model.state_dict(), runner.logger)
+        if batch_idx % (self.steps_per_epoch * self.prune_interval_epoch) == 0:  # every prune_interval_epoch epochs
+            calc_sparsity(runner.model.state_dict(), runner.logger)
+            runner.logger.info(f"Epoch #{batch_idx // self.steps_per_epoch} End")

    def after_test_epoch(self, runner, metrics):
        runner.logger.info("Switching to deployment model")
        # if repvgg style -> deploy
        for module in runner.model.modules():
            if hasattr(module, 'switch_to_deploy'):
                module.switch_to_deploy()
-        self._calc_sparsity(runner.model.state_dict(), runner.logger)
-
-    def _calc_sparsity(self, model_dict, logger):
-        weights_layers_num, total_weights, total_zeros = 0, 0, 0
-        prefix = next(iter(model_dict)).split('backbone.stage0')[0]
-        for k, v in model_dict.items():
-            if k.startswith(prefix) and k.endswith('weight'):
-                weights_layers_num += 1
-                total_weights += v.numel()
-                total_zeros += (v.numel() - v.count_nonzero())
-        logger.info(f"Model has {weights_layers_num} weight layers")
-        logger.info(f"Overall Sparsity is roughly: {100 * total_zeros / total_weights:.1f}%")
+        calc_sparsity(runner.model.state_dict(), runner.logger, True)
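
A hedged configuration sketch for the refactored hook; the explicit values mirror the prune setup (1488 iterations per epoch, recipe resumed at epoch 50) and are spelled out only for illustration, since the prune config above relies on the defaults:

    sparseml_hook = dict(type='SparseMLHook',
                         steps_per_epoch=1488,    # iterations per epoch
                         start_epoch=50,          # epoch at which the recipe resumes
                         prune_interval_epoch=2,  # report sparsity every 2 epochs
                         priority='NORMAL')
    custom_hooks = [sparseml_hook]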