
A CIFAR-10 classification project implemented in TensorFlow. It covers the full pipeline of training, prediction, and visualization, and its modules are independent of one another, which makes the project easy to extend.


ranjiewwen/TF_cifar10


CIFAR-10 TensorFlow Project

Get Started

  • Environment: tensorflow-gpu 1.8 + CUDA 9.0.
  • Dataset: first download the train and test sets from the Kaggle competition "CIFAR-10 - Object Recognition in Images".
  • Then use the utils/get_data_list.py and utils/get_dataset_mean.py scripts to generate train.txt and val.txt.
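The contents of utils/get_data_list.py and utils/get_dataset_mean.py are not shown here. The sketch below illustrates the general idea of those two preparation steps: turning a label file into shuffled train/val list files, and computing a per-channel mean.

It rests on several assumptions:

  • The labels come as a CSV with an `id,label` header, which is the layout of Kaggle's trainLabels.csv.
  • Images are stored as `<image_dir>/<id>.png`.
  • Each list line is written as `<image path> <class index>`. This format is hypothetical.

The functions `make_split`, `channel_mean`, and `write_list` are illustrative names, not the repository's API.

```python
import csv
import io
import os
import random

# CIFAR-10 class names in the usual index order (airplane=0, ..., truck=9).
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def make_split(labels_file, image_dir, val_fraction=0.1, seed=0):
    """Shuffle labelled images and split them into train/val list lines.

    Assumes a CSV with header 'id,label' and images stored as
    '<image_dir>/<id>.png'. Each returned line is '<image path> <class index>',
    a hypothetical format chosen for illustration.
    """
    lines = [
        f"{os.path.join(image_dir, row['id'] + '.png')} {CLASSES.index(row['label'])}"
        for row in csv.DictReader(labels_file)
    ]
    random.Random(seed).shuffle(lines)  # fixed seed keeps the split reproducible
    n_val = int(len(lines) * val_fraction)
    return lines[n_val:], lines[:n_val]


def channel_mean(images):
    """Per-channel (R, G, B) mean over images given as lists of pixel tuples."""
    totals = [0.0, 0.0, 0.0]
    count = 0
    for pixels in images:
        for r, g, b in pixels:
            totals[0] += r
            totals[1] += g
            totals[2] += b
            count += 1
    return [t / count for t in totals]


def write_list(path, lines):
    """Write one entry per line, e.g. to train.txt or val.txt."""
    with open(path, "w") as f:
        f.write("\n".join(lines) + "\n")


if __name__ == "__main__":
    demo_csv = io.StringIO("id,label\n1,frog\n2,truck\n3,cat\n4,ship\n")
    train, val = make_split(demo_csv, "train", val_fraction=0.25)
    print(len(train), len(val))  # 3 1
```

With the real data you would pass `open("trainLabels.csv", newline="")` instead of the in-memory demo. You would then call `write_list("train.txt", train)` and `write_list("val.txt", val)`. Decoding the PNG files into pixel tuples for `channel_mean` needs an image library (for example Pillow), which is not shown.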

How to Learn this project

  • Step one: modify the training parameters in config/cifar10_config.json.
  • Step two: learn how the datasets are loaded before training from src/datasets/cifar10_dataloader.py.
  • Step three: learn how to write a network from src/models/layers and src/models/simple_model.py; from there you can easily create your own model.
  • Step four: complete the training script tools/train_cifar10.py. Along the way you will implement the loss function (src/loss/cross_entropy.py) and the metric function (src/metrics/acc_metric.py). The script first builds the graph and then runs the session. It saves the trained models to the experiments/chekpoint folder and writes loss and accuracy summaries, viewable in TensorBoard, to the experiment/summary folder.
  • Step five: run the training script tools/train_cifar10.py.
  • Step six: once you have a trained model, use demo/prdict.py to predict an image's class name.
  • Step seven: use demo/visual.py to get extra information, such as the learned weights or visualized feature maps.
  • Other: tools/utils.py shows how to use some helper functions.
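The loss and metric from step four live in src/loss/cross_entropy.py and src/metrics/acc_metric.py, which are not reproduced here. The sketch below is a framework-free illustration of what such functions typically compute:

  • mean softmax cross-entropy over integer class labels;
  • top-1 accuracy.

In TensorFlow 1.x, these quantities are commonly obtained with `tf.nn.sparse_softmax_cross_entropy_with_logits` and `tf.argmax`. The function names below are hypothetical, not the repository's API.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def sparse_cross_entropy(batch_logits, labels):
    """Mean softmax cross-entropy for a batch with integer class labels."""
    losses = [-log_softmax(row)[y] for row, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


def accuracy(batch_logits, labels):
    """Fraction of rows whose arg-max logit equals the true label."""
    correct = sum(
        1 for row, y in zip(batch_logits, labels)
        if max(range(len(row)), key=row.__getitem__) == y
    )
    return correct / len(labels)


if __name__ == "__main__":
    logits = [[1.0, 5.0, 2.0], [3.0, 0.0, 1.0]]
    labels = [1, 2]
    print(sparse_cross_entropy(logits, labels), accuracy(logits, labels))
```

A quick sanity check: with all-zero logits over the 10 CIFAR-10 classes, the loss equals log(10), roughly 2.30. A freshly initialized network should report a loss near this value.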

The optimization process

  • You can find the detailed information there.
  • tools/train_cifar10.py applies several training tricks: learning-rate adjustment, data augmentation, dropout, weight decay, and stacked 3x3 convolutions. It shows how these raise model accuracy from 70%+ to 91%+. Adding depth through conv4_1 and conv4_2, however, does not improve validation accuracy.
  • tools/train_cifar10_v2.py adds batch normalization. In our runs it made training more unstable and may not improve validation accuracy, whereas stacking 3x3 convolutions improved validation accuracy remarkably.
  • tools/fintune_cifar10.py first loads ImageNet-pretrained weights and then fine-tunes ResNet-50.
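Why does stacking 3x3 convolutions help? The arithmetic below gives one common explanation.

Each stride-1, undilated 3x3 layer widens the receptive field by 2 pixels. So n stacked layers see a (2n+1)x(2n+1) window, but they use fewer weights than a single convolution of that size, and they insert extra nonlinearities between layers.

This sketch assumes C input channels and C output channels per layer, ignores biases, and picks 64 channels purely for illustration. It is not taken from the repository.

```python
def stacked_3x3_receptive_field(n_layers):
    """Receptive field of n stacked 3x3 convolutions (stride 1, no dilation)."""
    rf = 1
    for _ in range(n_layers):
        rf += 3 - 1  # each layer widens the field by kernel_size - 1
    return rf


def conv_weights(kernel, c_in, c_out):
    """Number of weights in one square conv layer, ignoring biases."""
    return kernel * kernel * c_in * c_out


def compare(n_layers, channels):
    """Return (receptive field, weights of the 3x3 stack, weights of one equivalent conv)."""
    rf = stacked_3x3_receptive_field(n_layers)
    stacked = n_layers * conv_weights(3, channels, channels)
    single = conv_weights(rf, channels, channels)
    return rf, stacked, single


if __name__ == "__main__":
    for n in (2, 3):
        rf, stacked, single = compare(n, 64)
        print(f"{n} x 3x3 -> {rf}x{rf} field: {stacked} weights vs {single} for one {rf}x{rf} conv")
```

Running it prints `2 x 3x3 -> 5x5 field: 73728 weights vs 102400 for one 5x5 conv` and `3 x 3x3 -> 7x7 field: 110592 weights vs 200704 for one 7x7 conv`. In general, n layers cost 9n C^2 weights, versus (2n+1)^2 C^2 for the single equivalent convolution.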

