Proto2Proto [arxiv]
conda env create -f environment.yml -n myenv python=3.6
conda activate myenv
- Refer to https://github.com/M-Nauta/ProtoTree to download and preprocess the Cars dataset
- For augmentation, run lib/protopnet/cars_augment.py (change the dataset paths if required)
- Create a symbolic link named datasets that points to the dataset folder
- The dataset paths should then be as follows:
trainDir: datasets/cars/train_augmented # Path-to-dataset
projectDir: datasets/cars/train # Path-to-dataset
testDir: datasets/cars/test # Path-to-dataset
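Before training, it can help to confirm that the symlink and the three directories above are in place. The following is a minimal sketch, not part of the repository: the helper names `link_datasets` and `missing_dataset_dirs` are hypothetical, and the repo root is assumed to be the directory that contains (or will contain) the `datasets` symlink.

```python
from pathlib import Path

# Relative paths from the README; keys mirror the config field names.
EXPECTED_DIRS = {
    "trainDir": "datasets/cars/train_augmented",
    "projectDir": "datasets/cars/train",
    "testDir": "datasets/cars/test",
}


def link_datasets(dataset_folder, root="."):
    """Hypothetical helper: create root/datasets -> dataset_folder if absent."""
    link = Path(root) / "datasets"
    if not link.is_symlink() and not link.exists():
        link.symlink_to(Path(dataset_folder).resolve(), target_is_directory=True)
    return link


def missing_dataset_dirs(root="."):
    """Return the config keys whose expected directory is missing under root.

    Path.is_dir() follows symlinks, so this also checks through `datasets`.
    """
    root = Path(root)
    return [key for key, rel in EXPECTED_DIRS.items() if not (root / rel).is_dir()]


# Example (paths are placeholders):
#   link_datasets("/data/proto2proto")
#   print(missing_dataset_dirs())  # [] when the layout is complete
```

Running the check from the repository root after creating the link lists any of trainDir, projectDir or testDir that still needs to be prepared.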
sh train_teacher.sh # For teacher training
sh train_baseline.sh # For baseline student training
sh train_kd.sh # For proto2proto student training
NOTE: For Proto2Proto student training, set the teacher checkpoint path in the backbone.loadPath field of Experiments/Resnet50_18_cars/kd_Resnet50_18/args.yaml, using the teacher model trained in the previous step. For example:
loadPath: Experiments/Resnet50_18_cars/teacher_Resnet50/org/models/protopnet_xyz.pth
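If you prefer to set this field from a script rather than by hand, one possible sketch is shown below. It assumes that `backbone.loadPath` denotes a `loadPath` key nested under a `backbone` mapping in args.yaml and that PyYAML is installed; `set_dotted` and `update_args_yaml` are hypothetical helpers, not functions from this repository.

```python
def set_dotted(cfg, dotted_key, value):
    """Set cfg["a"]["b"] = value for dotted_key "a.b", creating mappings as needed."""
    keys = dotted_key.split(".")
    node = cfg
    for key in keys[:-1]:
        child = node.get(key)
        if child is None:  # missing key or an empty YAML entry like "backbone:"
            child = node[key] = {}
        node = child
    node[keys[-1]] = value
    return cfg


def update_args_yaml(yaml_path, dotted_key, value):
    """Load args.yaml, set one dotted key and write the file back."""
    import yaml  # assumption: PyYAML (>= 5.1 for sort_keys) is available

    with open(yaml_path) as f:
        cfg = yaml.safe_load(f) or {}
    set_dotted(cfg, dotted_key, value)
    with open(yaml_path, "w") as f:
        yaml.safe_dump(cfg, f, default_flow_style=False, sort_keys=False)


# Example (protopnet_xyz.pth stands for your own teacher checkpoint):
#   update_args_yaml(
#       "Experiments/Resnet50_18_cars/kd_Resnet50_18/args.yaml",
#       "backbone.loadPath",
#       "Experiments/Resnet50_18_cars/teacher_Resnet50/org/models/protopnet_xyz.pth",
#   )
```

Note that rewriting the file through `yaml.safe_dump` drops any comments in args.yaml, so editing the line by hand may be preferable if you rely on them.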
For evaluation, set the model paths Teacherbackbone.loadPath, StudentBaselinebackbone.loadPath and StudentKDbackbone.loadPath in Experiments/Resnet50_18_cars/eval_setting/args.yaml, then run:
sh eval_setting.sh
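Because the evaluation config needs three checkpoint paths, a quick pre-flight check can catch a forgotten one before running eval_setting.sh. This is a minimal sketch under the assumption that each dotted name refers to a nested `loadPath` key in the loaded args.yaml; `get_dotted` and `unset_model_paths` are hypothetical helpers.

```python
from pathlib import Path

# Dotted keys named in the README for eval_setting/args.yaml.
EVAL_MODEL_KEYS = (
    "Teacherbackbone.loadPath",
    "StudentBaselinebackbone.loadPath",
    "StudentKDbackbone.loadPath",
)


def get_dotted(cfg, dotted_key):
    """Return cfg["a"]["b"] for "a.b", or None if any level is missing."""
    node = cfg
    for key in dotted_key.split("."):
        if not isinstance(node, dict) or key not in node:
            return None
        node = node[key]
    return node


def unset_model_paths(cfg, root="."):
    """List eval keys whose checkpoint path is empty or not an existing file."""
    bad = []
    for key in EVAL_MODEL_KEYS:
        path = get_dotted(cfg, key)
        if not path or not (Path(root) / path).is_file():
            bad.append(key)
    return bad


# Example (assumes PyYAML is installed):
#   import yaml
#   with open("Experiments/Resnet50_18_cars/eval_setting/args.yaml") as f:
#       print(unset_model_paths(yaml.safe_load(f)))  # [] when all three are set
```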
- Set the dataset paths appropriately
- Model paths must be set in the KD config (1 place) and in the eval setting config (3 places)
- Set CUDA_VISIBLE_DEVICES according to the available GPUs, and change batchSize if required
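One way to choose the GPUs without editing the scripts is to launch them with `CUDA_VISIBLE_DEVICES` set in the child process environment. The sketch below assumes the scripts do not override this variable themselves (if they export it internally, edit the script instead); `gpu_env` and `run_script` are hypothetical helpers.

```python
import os
import subprocess


def gpu_env(gpus, base=None):
    """Copy the environment and restrict CUDA to the comma-separated GPU ids."""
    env = dict(os.environ if base is None else base)
    env["CUDA_VISIBLE_DEVICES"] = gpus
    return env


def run_script(script, gpus="0"):
    """Run one of the shell scripts (e.g. train_kd.sh) on the chosen GPUs.

    check=True raises CalledProcessError if the script exits with an error.
    """
    return subprocess.run(["sh", script], env=gpu_env(gpus), check=True)


# Example: train the Proto2Proto student on GPUs 0 and 1.
#   run_script("train_kd.sh", gpus="0,1")
```

Using fewer GPUs usually means lowering batchSize in the corresponding args.yaml to fit in memory.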
Our code base is built on top of ProtoPNet.
If you use our work in your research, please cite us:
@inproceedings{Keswani2022Proto2ProtoCY,
title={Proto2Proto: Can you recognize the car, the way I do?},
author={Monish Keswani and Sriranjani Ramakrishnan and Nishant Reddy and Vineeth N. Balasubramanian},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR 2022)},
eprint={2204.11830},
archivePrefix={arXiv},
year={2022}
}