This repository provides the official PyTorch implementation of our NeurIPS 2022 paper:
Test-Time Prompt Tuning for Zero-shot Generalization in Vision-Language Models
Authors: Manli Shu, Weili Nie, De-An Huang, Zhiding Yu, Tom Goldstein, Anima Anandkumar, Chaowei Xiao
For more details, please check out our project page and paper.
This repository contains the implementation of TPT for image classification with a pre-trained CLIP model. We consider three different initializations for test-time prompt tuning:
- Using a hand-crafted prompt as initialization (e.g., "a photo of a ___")
- Using a learned soft prompt (CoOp) as initialization.
- Using the output of a trained conditional prompt learner (CoCoOp) as initialization.
This implementation is for the single-GPU configuration.
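As a rough illustration of the first option above, a hand-crafted prompt can be tokenized with CLIP and its token embeddings used to initialize the learnable prompt vectors that TPT then tunes at test time. The snippet below is a minimal sketch under that assumption; names such as `ctx`, `n_ctx`, and the learning rate are illustrative and not the repository's actual API.

```python
# Minimal sketch: initialize a learnable prompt from a hand-crafted template.
# Assumes the official OpenAI CLIP package; variable names are illustrative only.
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("RN50", device=device)

template = "a photo of a"                       # hand-crafted prompt (class name appended later)
tokens = clip.tokenize(template).to(device)     # [1, 77] token ids, incl. SOS/EOS

with torch.no_grad():
    embeddings = model.token_embedding(tokens)  # [1, 77, ctx_dim] token embeddings

# Take the embeddings of the 4 prompt tokens (positions 1..4, after SOS) and make
# them the learnable context to be optimized at test time.
n_ctx = 4
ctx = torch.nn.Parameter(embeddings[0, 1 : 1 + n_ctx].clone().float())
optimizer = torch.optim.AdamW([ctx], lr=5e-3)   # learning rate is illustrative
```

The CoOp and CoCoOp initializations follow the same idea, except the initial context comes from a trained (conditional) prompt learner instead of the tokenized template.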
To evaluate on ImageNet, ImageNet-V2, and ImageNet-Sketch (each with 1000 classes), you will need a GPU with more than 16GB of memory (16GB is not enough); this codebase is tested on a GPU with 24GB of memory. For the other datasets (each with at most a few hundred classes), a GPU with 16GB of memory is sufficient.
The code is tested on PyTorch 1.7.1.
We suggest downloading all datasets to a root directory (`${data_root}`) and renaming the directory of each dataset as specified in `${ID_to_DIRNAME}` in `./data/datautils.py`. This allows you to evaluate multiple datasets within the same run. If this is not feasible, you can evaluate different datasets separately and change `${data_root}` accordingly in the bash script.
For out-of-distribution generalization, we consider 5 datasets: ImageNet, ImageNet-A, ImageNet-V2, ImageNet-R, and ImageNet-Sketch (see the results table below).
For cross-dataset generalization, we consider 10 datasets and adopt the same train/val/test splits as CoOp. Please refer to this page, look for the download links of `split_zhou_${dataset_name}.json`, and put the json files under `./data/data_splits/`.
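For reference, a CoOp-style split file is a json with "train"/"val"/"test" lists. The sketch below shows one way to inspect the test split, assuming each entry has the form [relative_image_path, label, classname] as in CoOp's released splits; treat the exact schema, dataset name, and directory layout here as assumptions and verify them against the downloaded files.

```python
# Minimal sketch: inspect a CoOp-style split file for one dataset.
# The entry layout [relative_image_path, label, classname] and the folder name
# below are assumptions; check the actual json and ${ID_to_DIRNAME}.
import json
import os

data_root = "/path/to/data"                                 # ${data_root}
split_path = "./data/data_splits/split_zhou_Food101.json"   # example dataset

with open(split_path, "r") as f:
    splits = json.load(f)

test_items = splits["test"]
image_paths = [os.path.join(data_root, "food-101", p) for p, _, _ in test_items]
labels = [label for _, label, _ in test_items]
classnames = sorted({name for _, _, name in test_items})
print(f"{len(test_items)} test images, {len(classnames)} classes")
```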
We provide three bash scripts under `./scripts`. You can modify the paths and other args in the scripts.
An example to run TPT with CoOp initialization on out-of-distribution datasets:
bash ./scripts/test_coop.sh I/A/V/R/K
The command-line argument `${testsets}` can specify multiple test datasets separated by "/"; all of them should be stored under the same root directory `${data_root}`.
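Our reading of this convention (a sketch, not the repository's exact code) is that the "/"-separated string is split into individual set ids and each dataset is evaluated in turn:

```python
# Minimal sketch of the ${testsets} convention: "I/A/V/R/K" -> one evaluation per set id.
testsets = "I/A/V/R/K"

for set_id in testsets.split("/"):
    # Each set_id is resolved to a dataset directory under ${data_root}
    # (via ID_to_DIRNAME in ./data/datautils.py) and evaluated independently.
    print(f"evaluating set_id={set_id}")
```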
Note that for simplicity, we use `set_id` to denote different datasets. A complete list of `set_id` values can be found in `${ID_to_DIRNAME}` in `./data/datautils.py`.
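As an illustration, the mapping from `set_id` to dataset directory might look roughly like the dictionary below; the authoritative version lives in `./data/datautils.py`, so treat every key and folder name here as a placeholder.

```python
# Illustrative sketch only: the authoritative mapping is ID_to_DIRNAME
# in ./data/datautils.py; keys and folder names here are placeholders.
import os

ID_to_DIRNAME = {
    "I": "ImageNet",          # ImageNet (validation set)
    "A": "imagenet-a",        # ImageNet-A
    "V": "imagenetv2",        # ImageNet-V2
    "R": "imagenet-r",        # ImageNet-R
    "K": "imagenet-sketch",   # ImageNet-Sketch
    # ... plus one entry per cross-dataset benchmark
}

def dataset_dir(data_root: str, set_id: str) -> str:
    """Resolve a set_id to its dataset directory under ${data_root}."""
    return os.path.join(data_root, ID_to_DIRNAME[set_id])
```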
Below are results on ImageNet and four out-of-distribution variants (top-1 accuracy, %). "Average" is computed over all five test sets; "OOD Average" is over the four variants only.

Method | ImageNet (IN) | IN-A | IN-V2 | IN-R | IN-Sketch | Average | OOD Average |
---|---|---|---|---|---|---|---|
CLIP-RN50 | 58.16 | 21.83 | 51.41 | 56.15 | 33.37 | 44.18 | 40.69 |
Ensembled prompt | 59.81 | 23.24 | 52.91 | 60.72 | 35.48 | 46.43 | 43.09 |
CoOp | 63.33 | 23.06 | 55.40 | 56.60 | 34.67 | 46.61 | 42.43 |
CoCoOp | 62.81 | 23.32 | 55.72 | 57.74 | 34.48 | 46.81 | 42.82 |
TPT (ours) | 60.74 | 26.67 | 54.70 | 59.11 | 35.09 | 47.26 | 43.89 |
TPT + CoOp | 64.73 | 30.32 | 57.83 | 58.99 | 35.86 | 49.55 | 45.75 |
TPT + CoCoOp | 62.93 | 27.40 | 56.60 | 59.88 | 35.43 | 48.45 | 44.83 |
Figure (omitted): cross-dataset improvement normalized by the zero-shot baseline performance.
If you find our code useful or our work relevant, please consider citing:
@inproceedings{shu2022tpt,
author = {Shu, Manli and Nie, Weili and Huang, De-An and Yu, Zhiding and Goldstein, Tom and Anandkumar, Anima and Xiao, Chaowei},
title = {Test-Time Prompt Tuning for Zero-shot Generalization in Vision-Language Models},
booktitle = {NeurIPS},
year = {2022},
}
We thank the authors of CoOp/CoCoOp for their open-source implementation and instructions on data preparation.