This file shows how to reproduce our work.
Requirements:
- Python 3.6+ (tested with 3.7)
- PyTorch 1.1 or higher (tested with 1.2 and 1.3)
- CUDA 9.0 or higher (tested with 10.0)
- Linux (tested on Ubuntu 18.04)
- apex (see NVIDIA/apex)
- DALI (see NVIDIA/DALI)
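To check a machine against the requirements above, a small script like the following may help. This is a minimal sketch and not part of the repository: the helper names (parse_version, meets_minimum, check_environment) are hypothetical. The script compares only the numeric parts of version strings, and it reports whether apex and DALI (imported as nvidia.dali) can be imported rather than checking their versions.

```python
import importlib
import sys


def parse_version(version):
    """Turn a version string such as '1.2.0+cu92' into a tuple of ints, e.g. (1, 2, 0)."""
    core = version.split("+")[0]  # drop local build tags like '+cu92'
    parts = []
    for piece in core.split("."):
        # Keep only the leading digits of each piece (handles '0a0', '3rc1', ...).
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(version, minimum):
    """Return True if `version` is at least `minimum`, comparing numeric components."""
    v, m = parse_version(version), parse_version(minimum)
    # Pad with zeros so that '1.1' compares equal to '1.1.0'.
    width = max(len(v), len(m))
    v += (0,) * (width - len(v))
    m += (0,) * (width - len(m))
    return v >= m


def check_environment():
    """Print whether the installed Python, PyTorch, apex, and DALI look usable."""
    py = "{}.{}.{}".format(*sys.version_info[:3])
    print("Python", py, "OK" if meets_minimum(py, "3.6") else "too old (need 3.6+)")
    try:
        import torch
    except ImportError:
        print("PyTorch not installed")
    else:
        ok = meets_minimum(torch.__version__, "1.1")
        print("PyTorch", torch.__version__, "OK" if ok else "too old (need 1.1+)")
        # torch.version.cuda is None for CPU-only builds of PyTorch.
        print("CUDA version PyTorch was built with:", torch.version.cuda)
    for module_name in ("apex", "nvidia.dali"):
        try:
            importlib.import_module(module_name)
            print(module_name, "found")
        except ImportError:
            print(module_name, "not found")


if __name__ == "__main__":
    check_environment()
```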
You can use the following commands to train a classification network.
We train on 8 Tesla V100 GPUs. If you have 4 GPUs, change "--nproc_per_node=8" to "--nproc_per_node=4".
For details on the available parameters, please see main.py and main_mobile.py.
# For normal networks, like ResNet.
python3 -m torch.distributed.launch --nproc_per_node=8 imagenet.py -a fp10_resnet50 --b 32
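torch.distributed.launch starts one training process per GPU, so --nproc_per_node also determines how many processes contribute to each optimizer step. Assuming --b is the per-process batch size (please confirm in the training script), moving from 8 to 4 GPUs halves the global batch from 256 to 128. A common adjustment, which we have not verified for these models, is the linear scaling rule: scale the learning rate in proportion to the global batch size. The arithmetic is sketched below; the function names and the base learning rate of 0.1 are hypothetical.

```python
def global_batch_size(per_gpu_batch, nproc_per_node, nnodes=1):
    """Total images per optimizer step when each process handles `per_gpu_batch` images."""
    return per_gpu_batch * nproc_per_node * nnodes


def linearly_scaled_lr(base_lr, base_batch, new_batch):
    """Linear scaling rule: keep lr / global_batch constant when the batch changes."""
    return base_lr * new_batch / base_batch


if __name__ == "__main__":
    # Assumption: --b 32 is the per-process (per-GPU) batch size.
    reference = global_batch_size(32, 8)  # 8 V100s, as in the command above
    reduced = global_batch_size(32, 4)    # same command with --nproc_per_node=4
    base_lr = 0.1  # hypothetical learning rate tuned for the 8-GPU setting
    print(reference, reduced, linearly_scaled_lr(base_lr, reference, reduced))
```

Running it prints `256 128 0.05`.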
Our detection code is based on the mmdetection framework, and we thank its authors. For more details, please see the mmdetection GitHub repository.
# For training.
./tools/dist_train.sh local_configs/{config_file_name}.py 8
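dist_train.sh takes the config file and the number of GPUs as its two positional arguments. The checkpoint name epoch_24.pth used for testing suggests a 24-epoch ("2x") schedule. The fragment below is an assumption about what such a schedule looks like in an mmdetection v1.x config, modeled on the stock 2x configs rather than copied from local_configs. step_lr is a hypothetical helper that shows how the step policy maps an epoch to a learning rate, with warmup ignored.

```python
# Illustrative mmdetection v1.x-style schedule fragment (assumed, not copied from local_configs).
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[16, 22])  # decay the learning rate from 0-indexed epochs 16 and 22 onward
checkpoint_config = dict(interval=1)  # save epoch_N.pth after every epoch
total_epochs = 24  # so the final checkpoint is epoch_24.pth


def step_lr(base_lr, epoch, steps, gamma=0.1):
    """Learning rate for a 0-indexed `epoch` under a step policy (warmup ignored).

    The rate is multiplied by `gamma` once for every milestone in `steps`
    that the current epoch has reached.
    """
    num_decays = sum(1 for milestone in steps if epoch >= milestone)
    return base_lr * gamma ** num_decays


if __name__ == "__main__":
    for epoch in (0, 15, 16, 21, 22, 23):
        print(epoch, step_lr(optimizer['lr'], epoch, lr_config['step']))
```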
# For testing.
python3 tools/test.py local_configs/{config_file_name}.py \
    work_dirs/{model_path}/epoch_24.pth --gpus 8 --out work_dirs/{save_path}/{results_name}.pkl --eval bbox
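The --out flag saves the raw detections to a pickle file. Assuming the mmdetection v1.x result layout for bbox-only models (one entry per image, holding one (n, 5) array per class with rows [x1, y1, x2, y2, score]), the sketch below loads the file and counts detections per class above a score threshold. load_results, count_detections, and the default threshold of 0.3 are hypothetical choices for illustration.

```python
import pickle
import sys


def load_results(path):
    """Load the detection results that tools/test.py wrote with --out."""
    with open(path, "rb") as f:
        return pickle.load(f)


def count_detections(results, score_thr=0.3):
    """Count boxes per class whose score is at least `score_thr`.

    Assumes the mmdetection v1.x bbox layout: one entry per image, each a list
    with one (n, 5) array per class whose rows are [x1, y1, x2, y2, score].
    """
    counts = {}
    for per_image in results:
        for class_id, boxes in enumerate(per_image):
            kept = sum(1 for box in boxes if box[4] >= score_thr)
            counts[class_id] = counts.get(class_id, 0) + kept
    return counts


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python3 <this_script>.py work_dirs/{save_path}/{results_name}.pkl
    results = load_results(sys.argv[1])
    print(len(results), "images")
    for class_id, n in sorted(count_detections(results).items()):
        print("class", class_id, ":", n, "boxes")
```

Loading the file requires numpy to be installed, since the per-class arrays are pickled numpy arrays.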
We also provide a set of related tools for visualization, analysis, and counting parameters/FLOPs.
- For the parameter/FLOPs counter, please see count_Param.py.
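count_Param.py is not reproduced here, so its exact conventions are unknown. As a hedged illustration of what a parameter/FLOPs counter computes for a single convolution, the helpers below (hypothetical names) apply the standard formulas, using the same output-size formula as torch.nn.Conv2d. Counters differ in what they call "FLOPs": many report multiply-accumulates (MACs), while others report 2 × MACs.

```python
def conv_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a convolution (same formula as torch.nn.Conv2d)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_params(c_in, c_out, kernel, groups=1, bias=True):
    """Learnable parameters of a Conv2d layer with a square kernel."""
    weights = kernel * kernel * (c_in // groups) * c_out
    return weights + (c_out if bias else 0)


def conv2d_macs(c_in, c_out, kernel, out_h, out_w, groups=1):
    """Multiply-accumulates for one image: each output element costs k*k*(c_in/groups)."""
    return kernel * kernel * (c_in // groups) * c_out * out_h * out_w


if __name__ == "__main__":
    # Example: a ResNet-50-style stem conv (7x7, 3 -> 64, stride 2, padding 3, no bias).
    out = conv_output_size(224, kernel=7, stride=2, padding=3)
    print(out, conv2d_params(3, 64, 7, bias=False), conv2d_macs(3, 64, 7, out, out))
```

For a 224×224 input this prints `112 9408 118013952`, i.e. about 0.118 GMACs for the stem convolution alone.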