Skip to content

Commit

Permalink
Fix bug for RTMDet
Browse files Browse the repository at this point in the history
  • Loading branch information
yzbx committed Jan 31, 2023
1 parent 0926cc1 commit ce78026
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 9 deletions.
14 changes: 8 additions & 6 deletions mmyolo/engine/hooks/ymir_training_monitor_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from ymir_exc.util import (get_merged_config, write_ymir_monitor_process, write_ymir_training_result)
from ymir_exc.util import (get_merged_config, write_ymir_monitor_process,
write_ymir_training_result)


@HOOKS.register_module()
Expand Down Expand Up @@ -46,7 +47,7 @@ def after_val_epoch(self, runner, metrics: Optional[Dict[str, float]] = None) ->
"""
if runner.rank in [0, -1]:
N = len('coco/bbox_')
evaluation_result = {key[N:]: value for key, value in metrics.items()}
evaluation_result = {key[N:]: value for key, value in metrics.items()} # type: ignore
out_dir = self.ymir_cfg.ymir.output.models_dir
cfg_files = glob.glob(osp.join(out_dir, '*.py'))

Expand All @@ -64,11 +65,12 @@ def after_val_epoch(self, runner, metrics: Optional[Dict[str, float]] = None) ->
else:
warnings.warn(f'no best checkpoint found on {runner.epoch}')

last_ckpts = glob.glob(osp.join(out_dir, '*.pth'))
if len(last_ckpts) > 0:
logging.info(f'epoch={runner.epoch}, save {newest_best_ckpt} to result.yaml')
latest_ckpts = glob.glob(osp.join(out_dir, '*.pth'))
if len(latest_ckpts) > 0:
last_ckpt = max(latest_ckpts, key=osp.getctime)
logging.info(f'epoch={runner.epoch}, save {last_ckpt} to result.yaml')
write_ymir_training_result(self.ymir_cfg,
files=last_ckpts + cfg_files,
files=[last_ckpt] + cfg_files,
id='last',
evaluation_result=evaluation_result)
else:
Expand Down
4 changes: 3 additions & 1 deletion ymir/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ RUN cd /app && \
mkdir -p /img-man && \
mv ymir/img-man/*.yaml /img-man && \
mkdir /weights && \
mv ymir/weights/*.pth /weights
mv ymir/weights/*.pth /weights && \
mkdir -p /root/.cache/torch/hub/checkpoints && \
mv ymir/weights/imagenet/*.pth /root/.cache/torch/hub/checkpoints

ENV PYTHONPATH=.
WORKDIR /app
Expand Down
10 changes: 8 additions & 2 deletions ymir/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def recursive_modify_attribute(mmengine_cfgdict: Union[Config, ConfigDict], attr
mmengine_cfg.train_num_workers = workers_per_gpu
mmengine_cfg.train_dataloader.batch_size = samples_per_gpu
mmengine_cfg.train_dataloader.num_workers = workers_per_gpu
mmengine_cfg.optim_wrapper.optimizer.batch_size_per_gpu = samples_per_gpu

if 'batch_size_per_gpu' in mmengine_cfg.optim_wrapper.optimizer:
mmengine_cfg.optim_wrapper.optimizer.batch_size_per_gpu = samples_per_gpu

# modify model output channel
num_classes = len(ymir_cfg.param.class_names)
Expand Down Expand Up @@ -91,8 +93,12 @@ def recursive_modify_attribute(mmengine_cfgdict: Union[Config, ConfigDict], attr
max_epochs = int(ymir_cfg.param.max_epochs)
mmengine_cfg.train_cfg.max_epochs = max_epochs
mmengine_cfg.max_epochs = max_epochs
mmengine_cfg.default_hooks.param_scheduler.max_epochs = max_epochs

param_scheduler_type = mmengine_cfg.default_hooks.param_scheduler.type
if param_scheduler_type == 'YOLOv5ParamSchedulerHook':
mmengine_cfg.default_hooks.param_scheduler.max_epochs = max_epochs
elif param_scheduler_type == 'PPYOLOEParamSchedulerHook':
mmengine_cfg.default_hooks.param_scheduler.total_epochs = max_epochs
# modify checkpoint
mmengine_cfg.default_hooks.checkpoint['out_dir'] = ymir_cfg.ymir.output.models_dir

Expand Down
4 changes: 4 additions & 0 deletions ymir/weights/download_weights.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@ wget https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-8
wget https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco/ppyoloe_plus_m_fast_8xb8-80e_coco_20230104_193132-e4325ada.pth
wget https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco/ppyoloe_plus_l_fast_8xb8-80e_coco_20230102_203825-1864e7b3.pth

# RTMDet ImageNet-pretrained weights; the Dockerfile moves ymir/weights/imagenet/*.pth to /root/.cache/torch/hub/checkpoints
cd imagenet
wget https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth
wget https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth

0 comments on commit ce78026

Please sign in to comment.