This is unofficial implementation of MeanTeacher that purposes to reproduce results of Semi-Supervised Learning method for Image Classification problem.
conda create -n ssl python=3.7
conda activate ssl
conda install pytorch==1.4.0 torchvision==0.5.0 -c pytorch
pip install -r requirements.txt
Cifar:
python data/download_cifar10.py --data_dir datadir
Train on CIFAR10
python main.py \
--data_dir datadir/cifar10/ \
--dataset cifar10 \
--labels data/cifar10_labels/4000_balanced_labels/$trial.txt \
--model_arch shake_resnet26 \
--weight_decay 2e-4 \
--unp_weight 100 \
--rampup_length 5 \
--batch_size 512 \
--labeled_batch_size 124 \
--epochs 180 \
--ema_decay 0.97 \
--nesterov True \
--lr_rampdown_length 210 \
--device_ids 0python main.py \
--data_dir datadir/cifar10/ \
--dataset cifar10 \
--labels data/cifar10_labels/4000_balanced_labels/$trial.txt \
--model_arch shake_resnet26 \
--weight_decay 2e-4 \
--unp_weight 100 \
--rampup_length 5 \
--batch_size 512 \
--labeled_batch_size 124 \
--epochs 180 \
--ema_decay 0.97 \
--nesterov True \
--lr_rampdown_length 210 \
--device_ids 0python main.py \
--data_dir datadir/cifar10/ \
--dataset cifar10 \
--labels data/cifar10_labels/4000_balanced_labels/00.txt \
--model_arch shake_resnet26 \
--unp_weight 100 \
--rampup_length 5 \
--labeled_batch_size 62 \
--batch_size 256 \
--epochs 180 \
--device_ids 0
Distributed training on CIFAR10
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=2 --use_env main.py \
--data_dir datadir/cifar10/ \
--dataset cifar10 \
--labels data/cifar10_labels/4000_balanced_labels/$trial.txt \
--model_arch shake_resnet26 \
--weight_decay 2e-4 \
--unp_weight 100 \
--rampup_length 5 \
--batch_size 512 \
--labeled_batch_size 124 \
--epochs 180 \
--ema_decay 0.97 \
--nesterov True \
--lr_rampdown_length 210 \
--device_ids 0python main.py \
--data_dir datadir/cifar10/ \
--dataset cifar10 \
--labels data/cifar10_labels/4000_balanced_labels/$trial.txt \
--model_arch shake_resnet26 \
--weight_decay 2e-4 \
--unp_weight 100 \
--rampup_length 5 \
--batch_size 512 \
--labeled_batch_size 124 \
--epochs 180 \
--ema_decay 0.97 \
--nesterov True \
--lr_rampdown_length 210 \
--device_ids 0python main.py \
--data_dir datadir/cifar10/ \
--dataset cifar10 \
--labels data/cifar10_labels/4000_balanced_labels/00.txt \
--model_arch shake_resnet26 \
--unp_weight 100 \
--rampup_length 5 \
--labeled_batch_size 62 \
--batch_size 256 \
--epochs 180 \
--use_ddp
Evalate on CIFAR10
python test.py --data_dir datadir/cifar10/ \
--batch_size 20 \
--resume weights/cifar10/meanteacher/best_model.ckpt \
--test_set val \
--device_ids 0
Do 10 runs on 1000-labeled CIFAR10:
bash cifar10_1k_experiments.sh
Do 10 runs on 4000-labeled CIFAR10:
bash cifar10_4k_experiments.sh
Average results over 10 runs on CIFAR10 with 4k labeled images:
student_top1 | teacher_top1 | student_top5 | teacher_top5 | |
---|---|---|---|---|
my acc | 91.12 ± .54 | 91.59 ± .36 | 99.63 ± .07 | 99.72 ± .05 |
my err | 8.88 ± .54 | 8.41 ± .36 | .37 ± .07 | .28 ± .05 |
paper err | - | 6.28 ± .15 | - | - |
[1] Tarvainen, Antti, and Harri Valpola. "Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results." arXiv preprint arXiv:1703.01780 (2017). [paper][repo]
[2] pytorch_ema