This repository contains Python scripts for training and evaluating an image classification model based on the VGG-16 architecture using PyTorch. The trained model is capable of classifying images into two categories: dogs and cats. Additionally, there is an API script that implements the trained model and allows users to classify multiple images as either dogs or cats.
- Prerequisites
- Getting Started
- Project Structure
- Code Structure
- Results
- API Usage
- Contributing
- License
- Acknowledgments
Before running the code, make sure you have the following dependencies installed:
- Python 3.x
- PyTorch
- torchvision
- Matplotlib
- Scikit-learn
- FastAPI
- Uvicorn
- Clone the repository:
git clone https://github.com/anhphan2705/Image-Classification-Dog-Cat.git
- Install the required dependencies:
pip install torch torchvision matplotlib scikit-learn fastapi uvicorn
- Download the pre-trained model (optional):
A trained model, together with its classification report, is available for download. It reached 99.6% accuracy on my test set.
https://www.dropbox.com/s/nxllvz36o241dal/VGG-Train-9960.zip?dl=0
- Prepare the Data:
Place your training, validation, and test datasets in separate directories (train, val, test) inside the data directory as shown below:

    Image-Classification-VGG-Dog-Cat/
        cat_dog_train_model.py
        cat-dog-classifier-api.py
        pre-trained-model.pth (optional)
        output/
        data/
            train/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    image1.jpg
                    image2.jpg
                    ...
                ...
            val/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    image1.jpg
                    image2.jpg
                    ...
                ...
            test/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    image1.jpg
                    image2.jpg
                    ...
                ...

If you want to load a pre-trained model from your local computer, add the model_dir argument to the get_vgg16_pretrained_model() function in the cat_dog_train_model.py script.
- Training and Evaluation:
Open the cat_dog_train_model.py file and modify the necessary parameters, such as the data directory, the output paths, and the names of your data folders:

    file_dir = './data'
    out_model_dir = './output/VGG16_trained.pth'
    out_plot_dir = './output/epoch_progress.jpg'
    out_report_dir = './output/classification_report.txt'
    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'

Run the script:
python ./cat_dog_train_model.py
The script trains the VGG-16 model on the training dataset, evaluates its performance on the validation dataset, and saves the trained model for future use.
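Before starting a long training run, it can help to confirm that the data directory matches the layout from the Prepare the Data step. The following is a minimal sketch using only the standard library. The helper names `summarize_dataset` and `check_same_classes` and the list of image extensions are assumptions made for illustration; they are not part of cat_dog_train_model.py.

```python
from pathlib import Path

# Assumed image extensions; extend this set to match your dataset.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
# Mirrors the TRAIN, VAL, and TEST folder names used by the training script.
SPLITS = ("train", "val", "test")


def summarize_dataset(file_dir="./data"):
    """Return {split: {class_name: image_count}} for the expected layout."""
    root = Path(file_dir)
    summary = {}
    for split in SPLITS:
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"Missing split directory: {split_dir}")
        summary[split] = {
            class_dir.name: sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    return summary


def check_same_classes(summary):
    """Raise ValueError if the splits do not contain the same class folders."""
    reference = set(summary[SPLITS[0]])
    for split, classes in summary.items():
        if set(classes) != reference:
            raise ValueError(
                f"{split} has classes {sorted(classes)}, expected {sorted(reference)}"
            )
```

Calling `check_same_classes(summarize_dataset("./data"))` raises an error early if a split or class folder is missing, and the returned summary shows how many images each class contributes to each split.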
The project is organized as follows:
- cat_dog_train_model.py: The main script to train and evaluate the VGG-16 model.
- cat-dog-classifier-api.py: The script that serves the trained model as an API using FastAPI and Uvicorn.
- data/: Directory to store the training, validation, and test datasets.
- output/: Directory to save the trained model and evaluation results.
The training script, cat_dog_train_model.py, is structured as follows:
- Data Loading and Transformation:
  - get_data(file_dir): Loads and transforms the data using torchvision's ImageFolder and PyTorch's DataLoader.
- Model Creation and Modification:
  - get_vgg16_pretrained_model(model_dir='', weights=models.vgg16_bn(pretrained=True).state_dict(), len_target=1000): Retrieves the VGG-16 pre-trained model and modifies its classifier for the desired number of output classes.
- Evaluation:
  - eval_model(vgg, criterion, dataset='val'): Evaluates the model's performance on the specified dataset.
  - get_epoch_progress_graph(accuracy_train, loss_train, accuracy_val, loss_val, save_dir=out_plot_dir): Plots the progress of accuracy and loss over the training epochs.
  - get_classification_report(truth_values, pred_values): Generates a classification report and confusion matrix for the model predictions.
  - save_classification_report(truth_values, pred_values, out_report_dir): Saves the report to a preset location.
- Training:
  - train_model(vgg, criterion, optimizer, scheduler, num_epochs=10): Trains the model using the training dataset and evaluates its performance on the validation dataset.
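The model creation step above swaps the pre-trained classifier's output layer for one sized to the desired number of classes. The sketch below shows one way that idea could be written in PyTorch. The helper name `replace_output_layer` is hypothetical, and the actual get_vgg16_pretrained_model() in this repository may be implemented differently.

```python
import torch
from torch import nn


def replace_output_layer(classifier: nn.Sequential, num_classes: int) -> nn.Sequential:
    """Return a classifier whose final nn.Linear emits `num_classes` scores.

    All earlier layers are reused as-is; only the last layer is newly
    initialized, so it still has to be trained.
    """
    layers = list(classifier.children())
    last = layers[-1]
    if not isinstance(last, nn.Linear):
        raise TypeError("expected the last classifier layer to be nn.Linear")
    layers[-1] = nn.Linear(last.in_features, num_classes)
    return nn.Sequential(*layers)


# Possible usage with torchvision (downloads ImageNet weights on first call):
#   from torchvision import models
#   vgg = models.vgg16_bn(weights=models.VGG16_BN_Weights.DEFAULT)
#   vgg.classifier = replace_output_layer(vgg.classifier, num_classes=2)
```

Keeping the earlier classifier layers means their pre-trained weights carry over, and only the new two-class output layer starts from random initialization.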
The API script, cat-dog-classifier-api.py, is structured as follows:
- API Routes:
  - /: Serves the root route and displays a welcome message with a link to the API documentation.
  - /dog-cat-classification: API endpoint to classify multiple images as either dogs or cats using the fine-tuned VGG-16 model.
- Image Processing Functions:
  - convert_byte_to_arr(byte_image): Convert an image in byte format to a PIL Image object (RGB format).
  - convert_arr_to_byte(arr_image): Convert a numpy array image (RGB format) to byte format (JPEG).
  - multiple_to_one(images): Combine multiple images horizontally into a single image.
  - assign_image_label(images, labels, font="arial.ttf", font_size=25): Add labels to the input images.
  - get_data(np_images): Prepare the list of numpy array images for classification.
  - get_vgg16_pretrained_model(model_dir=MODEL_DIRECTORY, weights=models.VGG16_BN_Weights.DEFAULT): Retrieve the VGG-16 pre-trained model and replace its classifier with the fine-tuned one.
  - get_prediction(model, images): Perform image classification using the provided model.
- API Endpoints:
  - welcome_page(): Serves the root route ("/") and displays a welcome message with a link to the API documentation.
  - dog_cat_classification(in_images: list[UploadFile]): API endpoint to classify multiple images as either dogs or cats using a fine-tuned VGG-16 model.
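To illustrate what "combine multiple images horizontally" can look like for numpy arrays, here is a small sketch. It is not the repository's multiple_to_one() implementation: the function name `combine_horizontally` and the choice to pad shorter images with white rows at the bottom are assumptions made for this example.

```python
import numpy as np


def combine_horizontally(images, fill_value=255):
    """Stack HxWx3 arrays left to right, padding shorter images at the bottom."""
    if not images:
        raise ValueError("images must not be empty")
    max_height = max(img.shape[0] for img in images)
    padded = []
    for img in images:
        pad_rows = max_height - img.shape[0]
        # Pad only the height axis; width and channels are left unchanged.
        padded.append(
            np.pad(
                img,
                ((0, pad_rows), (0, 0), (0, 0)),
                mode="constant",
                constant_values=fill_value,
            )
        )
    # For 3-D arrays, hstack concatenates along the width axis.
    return np.hstack(padded)
```

Padding to a common height is needed because arrays can only be joined side by side when their heights match; resizing every image to the same height would be an alternative design.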
- The trained VGG-16 model is saved in the output directory as VGG16_trained.pth. You can use this model for inference on new images or for further fine-tuning.
- A chart of the per-epoch accuracy and loss is saved in output as epoch_progress.jpg.
- A full classification report for the model on the test dataset is saved in output as classification_report.txt.
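For readers who want to reproduce a report like classification_report.txt from their own predictions, the following sketch uses scikit-learn's `classification_report` and `confusion_matrix`. The helper names `build_report` and `save_report`, the file layout, and the label order (0 for cat, 1 for dog, which assumes class folders named cat and dog that are indexed alphabetically) are assumptions, not the script's actual code.

```python
from sklearn.metrics import classification_report, confusion_matrix


def build_report(truth_values, pred_values, class_names=("cat", "dog")):
    """Return (text report, confusion matrix) for integer class labels."""
    report = classification_report(
        truth_values, pred_values, target_names=list(class_names)
    )
    matrix = confusion_matrix(truth_values, pred_values)
    return report, matrix


def save_report(report, matrix, out_report_dir):
    """Write the report and confusion matrix to a text file."""
    with open(out_report_dir, "w", encoding="utf-8") as f:
        f.write(report)
        f.write("\nConfusion matrix (rows: truth, columns: prediction)\n")
        f.write(str(matrix))
```

The confusion matrix complements the single accuracy figure by showing which class the model confuses with which.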
- Create a classification model:
  - You can train a model with the cat_dog_train_model.py script. See Getting Started steps 4 and 5.
  - Or you can download the model I trained, which reached 99.6% accuracy (see the download link in Getting Started step 3).
- Follow the suggested folder directory:
Place the model somewhere, preferably as shown in Getting Started step 4.
Modify the MODEL_DIRECTORY constant on line 15 of the cat-dog-classifier-api.py script accordingly:

    MODEL_DIRECTORY = './output/VGG16_trained_9960.pth'

- Adjust variables:
You can also adjust the variables on lines 16 and 17 to change the API response. These font settings control the labels printed on the output images:

    FONT = "arial.ttf"
    FONT_SIZE = 25

- Run the API:
uvicorn cat-dog-classifier-api:app --host 0.0.0.0 --port 8000
- Access the API documentation:
Go to http://localhost:8000/docs in your web browser to access the API documentation and interact with the /dog-cat-classification endpoint.
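Besides the interactive documentation page, the endpoint can be called from a script. The sketch below builds a multipart upload whose field name, in_images, is taken from the dog_cat_classification(in_images: list[UploadFile]) signature. Several things are assumptions: that the route accepts POST requests, that the third-party requests package is installed (it is not in the prerequisites list), and the helper names `build_upload_payload` and `classify`. The response format is not documented here, so the raw response object is returned.

```python
import mimetypes
from pathlib import Path

# Assumes the API is running locally as started with the uvicorn command above.
API_URL = "http://localhost:8000/dog-cat-classification"


def build_upload_payload(image_paths, field_name="in_images"):
    """Build the multipart `files` list that requests accepts for several uploads."""
    payload = []
    for path in map(Path, image_paths):
        content_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
        payload.append((field_name, (path.name, path.read_bytes(), content_type)))
    return payload


def classify(image_paths, url=API_URL):
    """Send the images to the endpoint (assumed to accept POST multipart uploads)."""
    import requests  # third-party; imported lazily so the helper above needs only the stdlib

    response = requests.post(url, files=build_upload_payload(image_paths), timeout=60)
    response.raise_for_status()
    return response
```

Repeating the same field name for every file is how multipart forms express a list of uploads, which matches a parameter typed as a list of UploadFile.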
Contributions are welcome! If you have any suggestions or improvements for this code, please feel free to submit a pull request.
This project is licensed under the MIT License.
- The VGG-16 model implementation is based on the torchvision library in PyTorch.
- The dataset loading and transformation code is adapted from PyTorch's official documentation.
- The API implementation is based on FastAPI and Uvicorn.