
CUB 200 Image Classification

Grant Van Horn edited this page Feb 27, 2017 · 18 revisions

In this tutorial we are going to train a classification model using the CUB-200-2011 dataset. This dataset contains 200 species of birds, each with roughly 30 training images and 30 test images, and has become a standard benchmark for fine-grained visual classification.

Download the dataset

You can find the dataset website here. The dataset files are relatively small (about 1.3 GB when untarred) and should fit easily on your machine.

$ wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz
$ tar -xzf CUB_200_2011.tgz
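A quick way to confirm that the archive extracted completely is to count the images listed in train_test_split.txt. Each line of that file has the form "<image_id> <is_training_image>". The sketch below is not part of the tutorial's tooling. count_split is a hypothetical helper, and it assumes the archive unpacked into CUB_200_2011/. A complete copy lists 5994 training and 5794 test images, which is where the "roughly 30 per class" figure comes from.

```python
import os
from collections import Counter


def count_split(cub_dataset_dir):
    """Return (num_train, num_test) as listed in CUB's train_test_split.txt.

    Each line has the form "<image_id> <is_training_image>", where
    is_training_image is 1 for a training image and 0 for a test image.
    """
    counts = Counter()
    with open(os.path.join(cub_dataset_dir, "train_test_split.txt")) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines
            counts["train" if parts[1] == "1" else "test"] += 1
    return counts["train"], counts["test"]


# Usage (path assumed to be where tar extracted the archive):
#   count_split("CUB_200_2011")  # a complete copy lists 5994 train, 5794 test
```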

Create the tfrecord files

We will use the tfrecords repo to create the tfrecord files for training and testing the model. You'll need to clone that repo:

$ cd ~/code
$ git clone https://github.com/visipedia/tfrecords.git

Before we can call the create() function in create_tfrecords.py, we need to format the CUB data. You can find a script for doing this formatting here. Start an IPython session and paste that script into it with %cpaste. Now we can format the CUB dataset:

# Change these paths to match where you extracted the CUB dataset
cub_dataset_dir = "/media/drive2/datasets/CUB_200_2011"
cub_image_dir = "/media/drive2/datasets/CUB_200_2011/images"

# We need a file containing the size of each image in the dataset.
# You only need to create it once. scipy is required for this function.
# Alternatively, you can create this file yourself;
# each line should have the form <image_id> <width> <height>.
create_image_sizes_file(cub_dataset_dir, cub_image_dir)

train, test = format_dataset(cub_dataset_dir, cub_image_dir)
train, val = create_validation_split(train, images_per_class=5, shuffle=True)
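The sizes file only needs one "<image_id> <width> <height>" line per image. If the scipy-based helper fails in your environment (for example, scipy.misc.imread was removed in SciPy 1.2), you could write the file yourself. Here is a minimal sketch, assuming Pillow is installed and that image ids and relative paths come from CUB's images.txt, whose lines have the form "<image_id> <image_name>". write_image_sizes and its output_path argument are hypothetical. Check which filename the formatting script expects and write to that path.

```python
import os


def _pil_size(path):
    # Pillow only parses the image header here; .size is (width, height).
    from PIL import Image
    with Image.open(path) as im:
        return im.size


def write_image_sizes(cub_dataset_dir, cub_image_dir, output_path, get_size=_pil_size):
    """Write one "<image_id> <width> <height>" line per entry in images.txt."""
    images_txt = os.path.join(cub_dataset_dir, "images.txt")
    with open(images_txt) as src, open(output_path, "w") as dst:
        for line in src:
            parts = line.split(None, 1)
            if not parts:
                continue  # skip blank lines
            image_id, rel_path = parts[0], parts[1].strip()
            width, height = get_size(os.path.join(cub_image_dir, rel_path))
            dst.write("%s %d %d\n" % (image_id, width, height))
```

The size reader is a parameter, so you can swap in another image library without touching the file-writing logic.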

We have created three arrays holding train, validation and test data. The number of elements in each array should be:

  • Number of train images: 4994
  • Number of validation images: 1000
  • Number of test images: 5794
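These numbers follow directly from the split. CUB lists 5994 training images. Holding out 5 images for each of the 200 classes removes 200 × 5 = 1000 images for validation, which leaves 5994 − 1000 = 4994 for training. The sketch below is a hypothetical re-implementation of a per-class hold-out, not the formatting script's create_validation_split. It assumes you can supply a label_of function that returns a record's class label.

```python
import random
from collections import defaultdict


def split_per_class(records, label_of, images_per_class=5, seed=None):
    """Hold out `images_per_class` randomly chosen records of every class.

    Returns (train, val). `label_of` maps a record to its class label, e.g.
    lambda r: r["class"]["label"] if records are dicts in that layout
    (an assumption about the format, not confirmed by this page).
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for r in records:
        by_class[label_of(r)].append(r)
    train, val = [], []
    for label in sorted(by_class):
        items = list(by_class[label])
        rng.shuffle(items)
        # A class with fewer than images_per_class records goes entirely to val.
        val.extend(items[:images_per_class])
        train.extend(items[images_per_class:])
    return train, val
```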

We can now pass these arrays to the create() function:

from create_tfrecords import create

# Change this path
dataset_dir = "/media/drive2/tensorflow_datasets/cub/with_1k_val_split/"

train_errors = create(dataset=train, dataset_name="train", output_directory=dataset_dir,
                      num_shards=10, num_threads=2, shuffle=True)

val_errors = create(dataset=val, dataset_name="val", output_directory=dataset_dir,
                    num_shards=4, num_threads=2, shuffle=True)

test_errors = create(dataset=test, dataset_name="test", output_directory=dataset_dir,
                     num_shards=10, num_threads=2, shuffle=True)

We now have a dataset directory containing tfrecord files prefixed with either train, val or test that we can use to train and test a model.
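To check that every image made it into the shards, you can count the records in each split. A TFRecord file is a sequence of records. Each record is stored as a little-endian uint64 length, a 4-byte CRC of that length, the payload, and a 4-byte CRC of the payload, so counting only needs the length headers. The sketch below is a pure-Python illustration that skips CRC verification. Its helper names are hypothetical, and it relies only on the train/val/test filename prefixes described above. It also assumes no other files in dataset_dir share those prefixes.

```python
import glob
import os
import struct


def count_tfrecords(path):
    """Count the records in one TFRecord file without TensorFlow.

    Each record is laid out as: uint64 length (little-endian), uint32 CRC of
    the length, `length` bytes of payload, uint32 CRC of the payload.
    CRCs are skipped, not verified.
    """
    size = os.path.getsize(path)
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError("truncated length header in %s" % path)
            (length,) = struct.unpack("<Q", header)
            # Jump over the length CRC, the payload and the payload CRC.
            f.seek(4 + length + 4, os.SEEK_CUR)
            if f.tell() > size:
                raise ValueError("truncated record in %s" % path)
            count += 1
    return count


def count_split_records(dataset_dir, prefix):
    """Sum record counts over every file in dataset_dir whose name starts with prefix."""
    paths = sorted(glob.glob(os.path.join(dataset_dir, prefix + "*")))
    return sum(count_tfrecords(p) for p in paths)


# Usage, with dataset_dir as in the create() calls above:
#   for split in ("train", "val", "test"):
#       print(split, count_split_records(dataset_dir, split))
```

The totals should match the 4994 / 1000 / 5794 counts listed earlier. A shortfall presumably corresponds to entries reported in train_errors, val_errors or test_errors.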

Experiment Directory

We'll store all the experiment files in a directory called cub_image_experiment/. Create the following directory structure:

  • cub_image_experiment/
    • logdir/
      • val_summaries/
      • test_summaries/
      • finetune/
        • val_summaries/
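The directory layout above can be created in one step. Here is a minimal sketch using os.makedirs. The root name matches the tutorial, and make_experiment_dirs is a hypothetical helper. Because makedirs also creates intermediate directories, listing the leaf paths is enough to produce logdir/ and logdir/finetune/ as well.

```python
import os

# Leaf directories from the layout above, relative to the experiment root.
EXPERIMENT_SUBDIRS = [
    "logdir/val_summaries",
    "logdir/test_summaries",
    "logdir/finetune/val_summaries",
]


def make_experiment_dirs(root="cub_image_experiment"):
    for sub in EXPERIMENT_SUBDIRS:
        # exist_ok=True makes this safe to re-run on an existing layout.
        os.makedirs(os.path.join(root, sub), exist_ok=True)
    return root
```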

Configuration Files

We'll need two configuration files:

  • config_train.yaml: This will contain all the settings for image augmentation, the optimizer, the learning rate, model regularization, model snapshotting, and a few other options.
  • config_test.yaml: This will contain only the settings needed to test a model. It is essentially a subset of the training configuration with image augmentation turned off.

The training configuration file can be found here, and the test configuration file can be found here. Copy both files into the cub_image_experiment directory and rename them config_train.yaml and config_test.yaml, respectively.
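Since the test configuration is described as essentially a subset of the training configuration, a quick structural comparison can catch a mismatched or mis-copied file. The sketch below does not assume any particular key names. It only reports dotted key paths that appear in the test file but not in the training file. The helper names are hypothetical, and loading assumes PyYAML is installed. Because the test file is only "essentially" a subset, a non-empty result is not necessarily an error, but it is worth a look. Keys whose values differ, such as augmentation switches that are turned off, are not reported.

```python
def keys_missing_from_train(train_cfg, test_cfg, path=""):
    """List dotted key paths present in test_cfg but absent from train_cfg."""
    missing = []
    for key, value in test_cfg.items():
        full = path + "." + str(key) if path else str(key)
        if key not in train_cfg:
            missing.append(full)
        elif isinstance(value, dict) and isinstance(train_cfg[key], dict):
            missing.extend(keys_missing_from_train(train_cfg[key], value, full))
    return missing


def load_yaml(path):
    import yaml  # PyYAML, imported lazily so the comparison above has no dependency
    with open(path) as f:
        return yaml.safe_load(f)


# Usage:
#   keys_missing_from_train(load_yaml("cub_image_experiment/config_train.yaml"),
#                           load_yaml("cub_image_experiment/config_test.yaml"))
```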

Data visualization

It is always a good idea to visualize the inputs to the network before starting a long training process. This helps identify problems with the data early on.

Install the tf_classification repo if you haven't done so already.

Warm up training

Full training

Testing

Further Reading
