
A PyTorch implementation of ClipPrompt based on CVPR 2023 paper "CLIP for All Things Zero-Shot Sketch-Based Image Retrieval, Fine-Grained or Not"


leftthomas/ClipPrompt


ClipPrompt


Network Architecture

Requirements

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
conda install -c conda-forge torchmetrics
pip install git+https://github.com/openai/CLIP.git
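After running the commands above, you can confirm the environment by checking that each package is importable. This is a minimal sketch that uses only the standard library. `check_requirements` is a hypothetical helper, and the module names are the usual import names of these packages (the OpenAI CLIP repository installs a module called `clip`).

```python
import importlib.util

# Usual import names of the packages installed above (assumed, not taken from the repo).
REQUIRED_MODULES = ["torch", "torchvision", "torchmetrics", "clip"]


def check_requirements(modules=REQUIRED_MODULES):
    """Map each top-level module name to True if Python can locate it, else False."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}
```

Once installation has succeeded, `print(check_requirements())` should report `True` for every entry.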

Dataset

This repo uses the Sketchy Extended and TU-Berlin Extended datasets. You can download them from their official websites or from Google Drive. The data directory is structured as follows:

├── sketchy
│   ├── train
│   │   ├── sketch
│   │   │   ├── airplane
│   │   │   │   ├── n02691156_58-1.jpg
│   │   │   │   └── ...
│   │   │   └── ...
│   │   └── photo
│   │       same structure as sketch
│   └── val
│       same structure as train
└── tuberlin
    same structure as sketchy
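To check that a downloaded dataset matches this layout, you can walk one split/domain directory and list (category, image) pairs. This is a minimal sketch: `index_split` is a hypothetical helper, not part of the repository, and it assumes images use the `.jpg` extension as in the example path above.

```python
from pathlib import Path


def index_split(data_root, data_name, split, domain):
    """Collect (category, image_path) pairs for one split and domain.

    Assumes the layout <data_root>/<data_name>/<split>/<domain>/<category>/<image>.jpg,
    e.g. index_split("/home/data", "sketchy", "train", "sketch").
    """
    domain_dir = Path(data_root) / data_name / split / domain
    samples = []
    # Each sub-directory of the domain folder is one category (e.g. "airplane").
    for category_dir in sorted(p for p in domain_dir.iterdir() if p.is_dir()):
        for image_path in sorted(category_dir.glob("*.jpg")):
            samples.append((category_dir.name, image_path))
    return samples
```

Comparing the category names returned for `sketch` and `photo` gives a quick sanity check that the two domains line up.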

Usage

To train a model on the Sketchy Extended dataset, run:

python main.py --mode train --data_name sketchy

To test a model on the Sketchy Extended dataset, run:

python main.py --mode test --data_name sketchy --query_name <query image path>
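In test mode the script returns the gallery images closest to the query sketch. The ranking step can be sketched as follows. It assumes the query and gallery embeddings have already been produced by the model. `retrieve` and the toy vectors are hypothetical, and the distance mirrors the `1.0 - F.cosine_similarity(x, y)` function reported under Benchmarks, written here in plain Python.

```python
import math


def cosine_distance(x, y):
    """1 - cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / norm


def retrieve(query_emb, gallery, retrieval_num=8):
    """Return the names of the `retrieval_num` gallery items closest to the query.

    `gallery` maps an image name to its (hypothetical, precomputed) embedding vector.
    """
    ranked = sorted(gallery, key=lambda name: cosine_distance(query_emb, gallery[name]))
    return ranked[:retrieval_num]


# Toy example: a "cow" query should rank the cow photos first.
gallery = {
    "cow_1": [1.0, 0.1],
    "cow_2": [0.9, 0.2],
    "airplane_1": [-1.0, 0.0],
    "airplane_2": [0.0, 1.0],
}
top2 = retrieve([1.0, 0.0], gallery, retrieval_num=2)
```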

common arguments:

--data_root                   Datasets root path [default value is '/home/data']
--data_name                   Dataset name [default value is 'sketchy'](choices=['sketchy', 'tuberlin'])
--prompt_num                  Number of prompt embeddings [default value is 3]
--save_root                   Result saved root path [default value is 'result']
--mode                        Mode of the script [default value is 'train'](choices=['train', 'test'])

train arguments:

--batch_size                  Number of images in each mini-batch [default value is 64]
--epochs                      Number of training epochs [default value is 60]
--triplet_margin              Margin of triplet loss [default value is 0.3]
--encoder_lr                  Learning rate of encoder [default value is 1e-4]
--prompt_lr                   Learning rate of prompt embedding [default value is 1e-3]
--cls_weight                  Weight of classification loss [default value is 0.5]
--seed                        Random seed (-1 for no manual seed) [default value is -1]

test arguments:

--query_name                  Query image path [default value is '/home/data/sketchy/val/sketch/cow/n01887787_591-14.jpg']
--retrieval_num               Number of retrieved images [default value is 8]
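The options above map naturally onto `argparse`. The following is a hypothetical reconstruction that uses the documented names, defaults, and choices. It is not copied from `main.py`, and the argument types are inferred from the default values.

```python
import argparse


def build_parser():
    # Hypothetical reconstruction of the documented CLI; the real main.py may differ.
    parser = argparse.ArgumentParser(description="ClipPrompt (sketch of the documented options)")
    # common arguments
    parser.add_argument("--data_root", default="/home/data", type=str, help="Datasets root path")
    parser.add_argument("--data_name", default="sketchy", choices=["sketchy", "tuberlin"], help="Dataset name")
    parser.add_argument("--prompt_num", default=3, type=int, help="Number of prompt embeddings")
    parser.add_argument("--save_root", default="result", type=str, help="Result saved root path")
    parser.add_argument("--mode", default="train", choices=["train", "test"], help="Mode of the script")
    # train arguments
    parser.add_argument("--batch_size", default=64, type=int, help="Number of images in each mini-batch")
    parser.add_argument("--epochs", default=60, type=int, help="Number of training epochs")
    parser.add_argument("--triplet_margin", default=0.3, type=float, help="Margin of triplet loss")
    parser.add_argument("--encoder_lr", default=1e-4, type=float, help="Learning rate of encoder")
    parser.add_argument("--prompt_lr", default=1e-3, type=float, help="Learning rate of prompt embedding")
    parser.add_argument("--cls_weight", default=0.5, type=float, help="Weight of classification loss")
    parser.add_argument("--seed", default=-1, type=int, help="Random seed (-1 for no manual seed)")
    # test arguments
    parser.add_argument("--query_name", default="/home/data/sketchy/val/sketch/cow/n01887787_591-14.jpg",
                        type=str, help="Query image path")
    parser.add_argument("--retrieval_num", default=8, type=int, help="Number of retrieved images")
    return parser
```

For instance, `build_parser().parse_args(['--mode', 'test', '--data_name', 'tuberlin'])` overrides only the mode and dataset and keeps every other default.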

Benchmarks

The models are trained on a single NVIDIA GeForce RTX 3090 (24 GB) GPU with seed 42, prompt_lr 1e-3, and the distance function 1.0 - F.cosine_similarity(x, y); all other hyperparameters use their default values.
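The benchmark distance can be combined with the documented triplet margin (0.3) using PyTorch's built-in `nn.TripletMarginWithDistanceLoss`. This sketch covers only the triplet term. It assumes a sketch embedding acts as the anchor and photo embeddings act as the positive and negative. How the repository forms triplets and weights the classification loss (`--cls_weight`) is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cosine_distance(x, y):
    # Benchmark distance: 1 - cosine similarity, computed row-wise (dim=1).
    return 1.0 - F.cosine_similarity(x, y)


# loss = max(0, d(anchor, positive) - d(anchor, negative) + margin), averaged over the batch.
triplet_loss = nn.TripletMarginWithDistanceLoss(distance_function=cosine_distance, margin=0.3)

anchor = torch.tensor([[1.0, 0.0]])    # toy sketch embedding
positive = torch.tensor([[1.0, 0.0]])  # toy photo of the same category
negative = torch.tensor([[0.0, 1.0]])  # toy photo of a different category

# d(a, p) = 0 and d(a, n) = 1, so the margin is satisfied and the loss is 0.
loss = triplet_loss(anchor, positive, negative)
```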

Dataset              Prompt Num   mAP@200   mAP@all   P@100   P@200   Download
Sketchy Extended     3            71.9      64.3      70.8    68.1    MEGA
TU-Berlin Extended   3            75.3      66.0      73.9    69.7    MEGA
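The table reports mean average precision over the top 200 results (mAP@200) and over the full ranking (mAP@all), along with precision over the top 100 and 200 results. The sketch below computes these metrics in plain Python. It assumes a retrieved photo counts as relevant when its category matches the query sketch's. AP normalization conventions vary between papers, and the repository may compute these metrics with torchmetrics instead. This sketch divides by the number of relevant items found within the cutoff, so it may not match the exact evaluation code.

```python
def precision_at_k(relevant, k):
    """Fraction of the top-k results that are relevant.

    `relevant` is a list of booleans in ranked order (True = same category as the query).
    """
    return sum(relevant[:k]) / k


def average_precision(relevant, k=None):
    """AP over the top-k results; k=None uses the full ranking (AP@all).

    Averages precision@i over the ranks i at which a relevant item appears.
    """
    top = relevant if k is None else relevant[:k]
    hits, total = 0, 0.0
    for i, is_rel in enumerate(top, start=1):
        if is_rel:
            hits += 1
            total += hits / i  # precision at this rank
    return total / hits if hits else 0.0


def mean_average_precision(rankings, k=None):
    """mAP: mean of per-query AP values."""
    return sum(average_precision(r, k) for r in rankings) / len(rankings)
```

With `k=200` these functions correspond to mAP@200, and with `k=None` to mAP@all. `precision_at_k(r, 100)` and `precision_at_k(r, 200)` correspond to P@100 and P@200.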

Results

[Figure: visualization of retrieval results]
