Label Smoothing Plus: "LS+: Informed Label Smoothing for Improving Calibration in Medical Image Classification"
This repository contains the code and pretrained models for "LS+: Informed Label Smoothing for Improving Calibration in Medical Image Classification", which has been accepted at MICCAI 2024.
If the code or the paper has been useful in your research, please cite our work:
@inproceedings{lsplus,
title={LS+: Informed Label Smoothing for Improving Calibration in Medical Image Classification},
author={Sambyal, Abhishek Singh and Niyaz, Usma and Shrivastava, Saksham and Krishnan, Narayanan C and Bathula, Deepti R.},
booktitle={Medical Image Computing and Computer Assisted Intervention (MICCAI)},
year={2024}
}
The code is based on TensorFlow and requires a few additional dependencies, listed in tf.yml and tf1.yml. Please create the two conda environments using the following commands:
conda env create -f tf.yml
conda env create -f tf1.yml
ls_plus/
├── 01_retention_curves.ipynb
├── 02_all_histogram_plots/
├── 02_histogram_viz.ipynb
├── baselines.yml
├── chaoyang-data/
├── d1_ablation/
├── d1_chaoyang_code/
├── d2_ablation/
├── d2_mhist_code/
├── d3_skin_code/
├── distiller.yml
├── ISIC_2018/
├── MHIST/
├── r34_retention_curves.png
├── r50_retention_curves.png
└── README.md
The datasets can be downloaded from:
- Chaoyang: https://bupt-ai-cz.github.io/HSA-NRL/
- MHIST: https://bmirds.github.io/MHIST/
- ISIC: https://challenge.isic-archive.com/data/#2018
After downloading, place the data as follows:
- Chaoyang: Copy the train/ and test/ folders from the downloaded Chaoyang dataset into the chaoyang-data/ directory.
- MHIST: Copy the images/ folder from the downloaded MHIST dataset into the MHIST/ directory.
- ISIC 2018: Copy all ISIC2018_Task3_*/ folders from the downloaded ISIC dataset into the ISIC_2018/ directory.
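Before training, it can help to confirm that the data folders listed above are in place. This is a minimal sketch, assuming it is run from the ls_plus/ root; check_layout is a hypothetical helper (not part of the repository), and it only checks that the folders exist, not what they contain.

```python
from pathlib import Path

# Expected layout, taken from the copy instructions above (relative to ls_plus/).
# Hypothetical helper: this script is not part of the repository.
EXPECTED = {
    "Chaoyang": ["chaoyang-data/train", "chaoyang-data/test"],
    "MHIST": ["MHIST/images"],
}
# ISIC 2018 folders are matched by pattern, since their exact names come from the download.
ISIC_DIR, ISIC_PATTERN = "ISIC_2018", "ISIC2018_Task3_*"


def check_layout(root="."):
    """Return a dict mapping each dataset name to the list of paths that are missing."""
    root = Path(root)
    missing = {}
    for name, rel_paths in EXPECTED.items():
        missing[name] = [p for p in rel_paths if not (root / p).is_dir()]
    # glob() on a non-existent directory simply yields nothing.
    isic_folders = [p for p in (root / ISIC_DIR).glob(ISIC_PATTERN) if p.is_dir()]
    missing["ISIC 2018"] = [] if isic_folders else [f"{ISIC_DIR}/{ISIC_PATTERN}"]
    return missing


if __name__ == "__main__":
    for dataset, paths in check_layout().items():
        status = "ok" if not paths else "missing: " + ", ".join(paths)
        print(f"{dataset}: {status}")
```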
Links to all pretrained models will be available soon.
Use the bash_eval.py shell script inside each code directory (d1_chaoyang_code/, d2_mhist_code/, d3_skin_code/) to train the different approaches on the corresponding dataset.
Example Python command to run the baseline (HL):
python baseline.py -epochs 200 -test 0 -model resnet34 -bestmodelpath resnet34/1/vanilla_best_model.hdf5 -gpu 0 -csvfilename resnet34/metrics.csv;
--model: model to train (resnet34 or resnet50).
--bestmodelpath: path at which to save the trained model (or a directory path).
--gpu: index of the GPU to run the code on. Default: 0.
--csvfilename: CSV file in which to store all the metric values.
- Please check the bash_eval.py script for the correct command for each method.
- -test 0 trains the model and runs the evaluation code to generate the outputs (metrics, plots).
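The baseline command shown above can also be assembled programmatically, for instance to queue both backbones. The helper below is hypothetical (not part of the repository): it only reuses the flags and path pattern from the example command, and the run index in -bestmodelpath is an assumption. For any method other than the baseline, take the exact command from bash_eval.py.

```python
import shlex


def build_baseline_command(model="resnet34", run=1, epochs=200, test=0, gpu=0):
    """Assemble the baseline (HL) command from the README example as an argument list.

    Hypothetical helper: flag names are copied verbatim from the example command.
    """
    if model not in ("resnet34", "resnet50"):
        raise ValueError(f"unsupported model: {model}")
    return [
        "python", "baseline.py",
        "-epochs", str(epochs),
        "-test", str(test),
        "-model", model,
        "-bestmodelpath", f"{model}/{run}/vanilla_best_model.hdf5",
        "-gpu", str(gpu),
        "-csvfilename", f"{model}/metrics.csv",
    ]


if __name__ == "__main__":
    for model in ("resnet34", "resnet50"):
        # Only prints the command; pass the list to subprocess.run(...) to launch it.
        print(shlex.join(build_baseline_command(model=model)))
```

Returning a list rather than a single string avoids shell-quoting issues if the command is later passed to subprocess.run.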
If you have any questions or doubts, please feel free to open an issue in this repository or reach out to us at the email addresses provided in the paper.
The tables below contain the calibration results obtained after applying temperature scaling to the models reported in the paper. These are additional results that are not included in the paper.
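Temperature scaling divides a trained model's logits by a single scalar T > 0, fitted on held-out data by minimising the negative log-likelihood (NLL); it leaves the predicted class unchanged and only rescales confidence. The following is a minimal NumPy sketch under stated assumptions, not this repository's evaluation code: the function names are hypothetical, T is found by a simple grid search instead of the gradient-based fit that is more common, and the expected calibration error (ECE) uses 15 equal-width confidence bins.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def nll(logits, labels, temperature):
    """Mean negative log-likelihood of the temperature-scaled logits."""
    probs = softmax(logits / temperature)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))


def fit_temperature(logits, labels, grid=np.linspace(0.05, 10.0, 2000)):
    """Pick the T on a fixed grid that minimises validation NLL (assumed search range)."""
    losses = [nll(logits, labels, t) for t in grid]
    return float(grid[int(np.argmin(losses))])


def expected_calibration_error(probs, labels, n_bins=15):
    """Equal-width-bin ECE: sum over bins of (fraction of samples) * |accuracy - confidence|."""
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return float(ece)
```

Typical use: fit T with fit_temperature on validation logits and labels, then report expected_calibration_error(softmax(test_logits / T), test_labels) on the test set.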
Table: ResNet34, Chaoyang | ResNet50, Chaoyang
Table: ResNet34, MHIST | ResNet50, MHIST
In the above tables, the methods are: Hard Labels (HL), Label Smoothing (LS), Focal Loss (FL; FL-3 denotes focal loss with gamma = 3), Difference between Confidence and Accuracy (DCA), Multi-Class Difference in Confidence and Accuracy (MDCA), and Ours (LS+).
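The baselines listed above are standard training objectives. This minimal NumPy sketch writes each one out for a single batch, assuming softmax probabilities of shape (N, K) and integer labels; the function names and the default smoothing factor alpha = 0.1 are illustrative assumptions, LS+ itself is not shown, and the repository's own TensorFlow code should be treated as the reference. In training, the DCA and MDCA terms are added to the classification loss with a weighting factor.

```python
import numpy as np


def one_hot(labels, num_classes):
    return np.eye(num_classes)[labels]


def label_smoothing_targets(labels, num_classes, alpha=0.1):
    """LS: mix the one-hot target with a uniform distribution, (1 - alpha) * y + alpha / K."""
    return (1.0 - alpha) * one_hot(labels, num_classes) + alpha / num_classes


def cross_entropy(probs, targets, eps=1e-12):
    """Mean cross-entropy between (soft or hard) targets and predicted probabilities."""
    return float(-np.mean(np.sum(targets * np.log(probs + eps), axis=1)))


def focal_loss(probs, labels, gamma=3.0, eps=1e-12):
    """FL: mean of -(1 - p_y)^gamma * log(p_y); gamma = 3 corresponds to FL-3."""
    p_y = probs[np.arange(len(labels)), labels]
    return float(np.mean(-((1.0 - p_y) ** gamma) * np.log(p_y + eps)))


def dca_penalty(probs, labels):
    """DCA term: |mean accuracy - mean confidence| over the batch."""
    conf = probs.max(axis=1)
    acc = (probs.argmax(axis=1) == labels).astype(float)
    return float(abs(acc.mean() - conf.mean()))


def mdca_penalty(probs, labels):
    """MDCA term: average over classes of |mean predicted prob - mean one-hot label|."""
    targets = one_hot(labels, probs.shape[1])
    return float(np.mean(np.abs(probs.mean(axis=0) - targets.mean(axis=0))))
```

With gamma = 0 the focal loss reduces to ordinary cross-entropy on hard labels, and with alpha = 0 the LS targets reduce to the hard labels (HL).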