FiT: Parameter Efficient Few-shot Transfer Learning for Personalized and Federated Image Classification
This repository contains the code to reproduce the experiments carried out in: FiT: Parameter Efficient Few-shot Transfer Learning for Personalized and Federated Image Classification
This code requires the following:
- Python 3.8 or greater
- PyTorch 1.11 or greater (most of the code is written in PyTorch)
- TensorFlow 2.8 or greater (for reading VTAB datasets)
- TensorFlow Datasets 4.5.2 or greater (for reading VTAB datasets)
- gsutil (for downloading the Quick, Draw! dataset)
- The majority of the experiments in the paper are executed on a single NVIDIA A100 GPU with 80 GB of memory. By reducing the batch size, it is possible to run on a GPU with less memory, but classification results may be different.
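If a smaller GPU forces a reduced batch size, one common workaround is gradient accumulation: summing gradients over several micro-batches before each update, which reproduces the full-batch gradient exactly for losses that are sums over examples. This is a generic sketch of that identity, not a feature of this repo (all names here are hypothetical):

```python
import numpy as np

# Illustration: for a loss that sums over examples, accumulating
# micro-batch gradients equals the full-batch gradient.
rng = np.random.default_rng(0)
X = rng.standard_normal((8, 3))   # 8 examples, 3 features
y = rng.standard_normal(8)
w = rng.standard_normal(3)

def grad(Xb, yb, w):
    # Gradient of 0.5 * ||Xb @ w - yb||^2 with respect to w.
    return Xb.T @ (Xb @ w - yb)

full = grad(X, y, w)
accumulated = sum(grad(X[i:i + 2], y[i:i + 2], w) for i in range(0, 8, 2))
print(np.allclose(full, accumulated))  # the two gradients match
```

Note that batch-dependent components such as batch normalization statistics are not captured by this identity, which is one reason results can still differ at smaller batch sizes.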
The following steps will take a considerable length of time and disk space.
- Clone or download this repository.
- The VTAB-v2 benchmark uses TensorFlow Datasets. The majority of these are downloaded and pre-processed upon first use. However, the Diabetic Retinopathy and Resisc45 datasets need to be downloaded manually. Click on the links for details.
- Switch to the `src` directory in this repo and download the BiT pretrained model:

  ```
  wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz
  ```
- For the federated learning experiments on the Quick, Draw! dataset, download the dataset as follows:

  ```
  mkdir quickdraw-npy
  gsutil -m cp gs://quickdraw_dataset/full/numpy_bitmap/*.npy quickdraw-npy
  ```
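Each downloaded `.npy` file stores one class's drawings as an `N × 784` array of `uint8` pixels (flattened 28×28 grayscale bitmaps). A minimal sketch of inspecting one file; the filename and the fabricated stand-in data here are hypothetical:

```python
import numpy as np

# Fabricate a tiny array in the Quick, Draw! numpy_bitmap layout
# (N flattened 28x28 grayscale drawings) instead of a real download.
drawings = np.zeros((5, 784), dtype=np.uint8)

# With real data you would instead do:
#   drawings = np.load("quickdraw-npy/cat.npy")
images = drawings.reshape(-1, 28, 28)  # one 28x28 image per row
print(images.shape)  # (5, 28, 28)
```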
Switch to the `src` directory in this repo and execute any of the commands below.
1-shot:

```
python run_fit.py --classifier <qda, lda, or protonets> --examples_per_class 1 -i 0 --mode few_shot -c <path to checkpoint directory> --download_path_for_tensorflow_datasets <path to where you want the TensorFlow Datasets downloaded>
```
Multi-shot:

```
python run_fit.py --classifier <qda, lda, or protonets> --examples_per_class <2-10, or -1 for all> --mode few_shot -c <path to checkpoint directory> --download_path_for_tensorflow_datasets <path to where you want the TensorFlow Datasets downloaded>
```
VTAB-1000:

```
python run_fit.py --classifier <qda, lda, or protonets> --mode vtab_1000 --do_not_split -c <path to checkpoint directory> --download_path_for_tensorflow_datasets <path to where you want the TensorFlow Datasets downloaded>
```
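For intuition about the `protonets` option: a ProtoNets-style head builds one prototype per class (the mean embedding of that class's support examples) and labels each query by its nearest prototype. A self-contained numpy sketch of that idea, not the repo's implementation:

```python
import numpy as np

def protonet_predict(support, support_labels, queries):
    """Nearest-prototype classification: class means in embedding space."""
    classes = np.unique(support_labels)
    # One prototype per class: the mean of its support embeddings.
    prototypes = np.stack(
        [support[support_labels == c].mean(axis=0) for c in classes])
    # Squared Euclidean distance from every query to every prototype.
    d = ((queries[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    return classes[d.argmin(axis=1)]

# Toy 2-D "embeddings": two well-separated classes, two queries.
support = np.array([[0.0, 0.0], [0.2, 0.0], [5.0, 5.0], [5.2, 5.0]])
labels = np.array([0, 0, 1, 1])
queries = np.array([[0.1, 0.1], [5.1, 4.9]])
print(protonet_predict(support, labels, queries))  # [0 1]
```

In the repo the embeddings would come from the (adapted) BiT backbone rather than raw pixels, and the LDA/QDA heads replace the distance rule with Gaussian class-conditional densities.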
Federated learning:

```
python run_fed_avg.py --data_path <path to dataset> --checkpoint_dir <path to checkpoint directory> \
    --num_local_epochs <number of local epochs> --iterations <number of communication rounds> \
    --num_clients <number of clients> --num_classes <number of classes per client> \
    --shots_per_client <shots per client> --dataset <quickdraw, cifar100> --use_npy_data
```
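The aggregation step in FedAvg averages each client's locally updated parameters, weighted by that client's number of examples, once per communication round. A minimal numpy sketch of one round's aggregation (names hypothetical; this is not the repo's code):

```python
import numpy as np

def fed_avg(client_params, client_sizes):
    """Weighted average of client parameter vectors (one FedAvg round)."""
    sizes = np.asarray(client_sizes, dtype=float)
    weights = sizes / sizes.sum()            # proportional to data held
    stacked = np.stack(client_params)        # (num_clients, num_params)
    return (weights[:, None] * stacked).sum(axis=0)

# Two clients with 10 and 30 examples; the larger client counts 3x more.
merged = fed_avg([np.array([1.0, 0.0]), np.array([3.0, 4.0])], [10, 30])
print(merged)  # [2.5 3. ]
```

In FiT only the small set of adapted parameters is communicated and averaged, which is what makes the federated setting parameter-efficient.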
Alternatively, for CIFAR100 the bash script can be used:

```
bash fed_avg_cifar100.sh $num_clients $num_shots_per_client $data_path $checkpoint_dir
```

and for QuickDraw:

```
bash fed_avg_quickdraw.sh $num_clients $num_shots_per_client $data_path $checkpoint_dir
```
Other hyperparameters in these scripts are set to the values used for the federated learning experiments in the paper.
To ask questions or report problems, please open an issue on the issue tracker.
If you use this code, please cite our paper.
```
@article{shysheya2022fit,
  title={FiT: Parameter Efficient Few-shot Transfer Learning for Personalized and Federated Image Classification},
  author={Shysheya, Aliaksandra and Bronskill, John and Patacchiola, Massimiliano and Nowozin, Sebastian and Turner, Richard E.},
  journal={arXiv preprint arXiv:2206.08671},
  year={2022}
}
```