Zero-shot sketch-based image retrieval using a domain adversarial neural network

ashok-arjun/Zero-Shot-Sketch-Based-Image-Retrieval


Zero Shot Sketch-Based Image Retrieval

This repository addresses the problem of retrieving images from a large database using ambiguous sketches as queries. The problem is tackled in the zero-shot setting, where the test sketches and images come from classes unseen during training. A deep feature extractor embeds both sketches and images, and the distances between a sketch embedding and the image embeddings are used to retrieve the top k closest images from the image database of the unseen classes.
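The retrieval step can be illustrated with a minimal, framework-free sketch. It assumes Euclidean distance (the README does not state which metric the code uses) and uses toy 2-D vectors in place of real embeddings; the function names are hypothetical.

```python
import math

def euclidean(a, b):
    """Euclidean distance between two equal-length embedding vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def retrieve_top_k(sketch_embedding, image_embeddings, k=5):
    """Return the indices of the k image embeddings closest to the sketch."""
    distances = [euclidean(sketch_embedding, e) for e in image_embeddings]
    # Sort image indices by increasing distance to the query sketch.
    ranked = sorted(range(len(image_embeddings)), key=lambda i: distances[i])
    return ranked[:k]

# Toy embeddings standing in for the feature extractor's outputs (assumption).
sketch = [0.0, 0.0]
images = [[3.0, 4.0], [0.1, 0.0], [1.0, 1.0], [0.0, 0.5]]
print(retrieve_top_k(sketch, images, k=2))  # prints [1, 3]
```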

The model is trained with the standard triplet loss together with a domain loss. The domain loss comes from a domain classifier that is trained to differentiate between sketches and images.
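As a rough illustration, the standard triplet loss can be written as a hinge on the difference between the anchor-positive and anchor-negative distances. The sketch below is not the repository's implementation: the margin value, the Euclidean distance, and the weighted-sum combination with the domain loss (suggested by the two loss-weight options of train.py) are all assumptions.

```python
import math

def distance(a, b):
    """Euclidean distance between two equal-length vectors (assumed metric)."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def triplet_loss(anchor, positive, negative, margin=1.0):
    """max(d(anchor, positive) - d(anchor, negative) + margin, 0)."""
    return max(distance(anchor, positive) - distance(anchor, negative) + margin, 0.0)

def total_loss(triplet, domain, triplet_weight=1.0, domain_weight=1.0):
    """Hypothetical weighted sum of the two loss terms."""
    return triplet_weight * triplet + domain_weight * domain

# Anchor: a sketch embedding; positive: a photo of the same class;
# negative: a photo of a different class (toy values).
easy = triplet_loss([0.0, 0.0], [0.0, 1.0], [3.0, 4.0])   # negative is far away
hard = triplet_loss([0.0, 0.0], [0.0, 1.0], [0.0, 1.5])   # negative is too close
print(easy, hard)
```

A triplet contributes to the loss only when the negative is not at least `margin` farther from the anchor than the positive is.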

The embedding is passed through a gradient reversal layer [1], then into a domain-classifier network, whose output feeds the domain loss. The gradient reversal layer acts as an identity layer in the forward pass and multiplies the gradient by -1 in the backward pass. As a result, the main network is pushed to fool the domain-classifier network, i.e. to learn a domain-agnostic representation.
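The behaviour of a gradient reversal layer can be shown with a small framework-free sketch. It is only an illustration of the idea: in a deep learning framework this would normally be a custom autograd operation, and the class name, the list-based tensors, and the `lambd` scaling factor here are assumptions rather than the repository's code.

```python
class GradientReversal:
    """Identity in the forward pass; scales gradients by -lambd in the backward pass."""

    def __init__(self, lambd=1.0):
        self.lambd = lambd

    def forward(self, x):
        # The embedding passes through unchanged.
        return list(x)

    def backward(self, grad_output):
        # The gradient flowing back from the domain classifier is negated,
        # so the feature extractor is updated to increase the domain loss.
        return [-self.lambd * g for g in grad_output]

grl = GradientReversal(lambd=1.0)
embedding = [1.0, -2.0]
grad_from_classifier = [0.5, -0.25]
print(grl.forward(embedding))             # same values as the input
print(grl.backward(grad_from_classifier)) # sign-flipped gradient
```

The domain classifier itself still receives ordinary gradients and keeps learning to separate sketches from images; only the signal reaching the feature extractor is reversed.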

Architecture Overview

† Diagram created by Arjun Ashok using app.diagrams.net

Results on unseen (zero-shot) classes

The table below presents a few qualitative results of the model on unseen test classes. Each row shows the top 5 retrieved images for a query sketch, from left to right.

Query Sketch 1
Retrieved Image 1 Retrieved Image 2 Retrieved Image 3 Retrieved Image 4 Retrieved Image 5
Query Sketch 2
Retrieved Image 1 Retrieved Image 2 Retrieved Image 3 Retrieved Image 4 Retrieved Image 5

Instructions

Installation

Please execute the following command to install the required libraries:

pip install -r requirements.txt

Data

Download the data by executing:

bash download_data.sh

Training

The file train.py can be invoked with the following arguments:

usage: train.py [-h] --data_dir DATA_DIR --batch_size BATCH_SIZE
                --checkpoint_dir CHECKPOINT_DIR --epochs EPOCHS
                [--domain_loss_ratio DOMAIN_LOSS_RATIO]
                [--triplet_loss_ratio TRIPLET_LOSS_RATIO]
                [--grl_threshold_epoch GRL_THRESHOLD_EPOCH]
                [--print_every PRINT_EVERY]

Training of SBIR

optional arguments:
  -h, --help            show this help message and exit
  --data_dir DATA_DIR   Data directory path. Directory should contain two
                        folders - sketches and photos, along with 2 .txt files
                        for the labels
  --batch_size BATCH_SIZE
                        Batch size to process the train sketches/photos
  --checkpoint_dir CHECKPOINT_DIR
                        Directory to save checkpoints
  --epochs EPOCHS       Number of epochs
  --domain_loss_ratio DOMAIN_LOSS_RATIO
                        Domain loss weight
  --triplet_loss_ratio TRIPLET_LOSS_RATIO
                        Triplet loss weight
  --grl_threshold_epoch GRL_THRESHOLD_EPOCH
                        Threshold epoch for GRL lambda
  --print_every PRINT_EVERY
                        Logging interval in iterations

It is advised to use a GPU for training. The code automatically detects and uses a GPU, if available.

Inference

The file evaluate.py can be invoked with the following arguments:

usage: evaluate.py [-h] [--model MODEL] --data DATA [--num_images NUM_IMAGES]
                   [--num_sketches NUM_SKETCHES] [--batch_size BATCH_SIZE]
                   [--output_dir OUTPUT_DIR]

Evaluation of SBIR

arguments:
  -h, --help            show this help message and exit
  --model MODEL         Model checkpoint path
  --data DATA           Data directory path. Directory should contain two
                        folders - sketches and photos, along with 2 .txt files
                        for the labels
  --num_images NUM_IMAGES
                        Number of random images to output for every
                        sketch
  --num_sketches NUM_SKETCHES
                        Number of random sketches to output
  --batch_size BATCH_SIZE
                        Batch size to process the test sketches/photos
  --output_dir OUTPUT_DIR
                        Directory to save output sketch and images

It is advised to use a GPU for evaluation. The code automatically detects and uses a GPU, if available.
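For scripted runs, the evaluation command can be assembled programmatically. The sketch below only uses the flags listed in the usage text above; the helper name and all paths and values in the example are placeholders, not defaults taken from the repository.

```python
import shlex
import sys

def build_evaluate_command(data, model=None, num_images=None,
                           num_sketches=None, batch_size=None, output_dir=None):
    """Build an argument list for evaluate.py; --data is the only required flag."""
    cmd = [sys.executable, "evaluate.py", "--data", str(data)]
    optional = {
        "--model": model,
        "--num_images": num_images,
        "--num_sketches": num_sketches,
        "--batch_size": batch_size,
        "--output_dir": output_dir,
    }
    # Append only the options that were explicitly provided.
    for flag, value in optional.items():
        if value is not None:
            cmd += [flag, str(value)]
    return cmd

# Placeholder paths and values (assumptions).
cmd = build_evaluate_command("data", model="checkpoints/model.pth",
                             num_images=5, output_dir="outputs")
print(" ".join(shlex.quote(part) for part in cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the directory that contains evaluate.py.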

References

  1. Ganin, Yaroslav, et al. "Domain-Adversarial Training of Neural Networks". Journal of Machine Learning Research, vol. 17, 2016, pp. 1-35. URL: http://jmlr.org/papers/v17/15-239.html