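- Install TensorFlow and the other dependencies listed in `requirements.txt` (e.g. `pip install -r requirements.txt`)
- Get the WIDER FACE dataset (see below)
- Run `python train.py` (this takes a while, depending on your machine)
- Run `python detect.py --image my_image.jpg`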
- `./data/data_generator.py` - generates train/val data from WIDER FACE
- `./model/model.py` - builds the TF model
- `./model/loss.py` - defines the loss function used for training
- `./model/validation.py` - defines the validation used during training
- `./config.py` - stores the network/training/validation config
- `./detect.py` - runs the model against a given image and generates an output image
- `./draw_boxes` - helper function for `./detect.py`, draws boxes on a cv2 image
- `./print_model.py` - prints the current model structure
- `./train.py` - starts training the model and creates weights based on the training results and the validation function
We use the WIDER FACE dataset. It contains over 32k images with almost 400k faces and is publicly available at http://shuoyang1213.me/WIDERFACE/

Please put all the data into the `./data` folder.

The data structure is described in `./data/wider_face_split/readme.txt`. We only use the box annotations, but more data is available if someone wants to use it.
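The box annotations live in `wider_face_train_bbx_gt.txt`. Below is a minimal, standard-library-only sketch of reading that file into `{image_path: [boxes]}`; `parse_wider_annotations` is a hypothetical name, not the project's code. It assumes the layout described in `readme.txt`: an image path line, a face-count line, then one line per face starting with `x1 y1 w h` followed by attribute flags (ignored here). It also tolerates an all-zero placeholder row after images with 0 faces; verify this against your copy of the files.

```python
from typing import Dict, List, Tuple

Box = Tuple[int, int, int, int]  # (x1, y1, w, h) in pixels, as stored in the file


def _is_numeric_row(line: str) -> bool:
    """True for annotation rows like '0 0 0 0 0 0 0 0 0 0' (at least 4 integers)."""
    tokens = line.split()
    return len(tokens) >= 4 and all(t.lstrip("-").isdigit() for t in tokens)


def parse_wider_annotations(text: str) -> Dict[str, List[Box]]:
    """Hypothetical parser for the WIDER FACE *_bbx_gt.txt format."""
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    result: Dict[str, List[Box]] = {}
    i = 0
    while i < len(lines):
        path = lines[i]
        count = int(lines[i + 1])
        i += 2
        boxes: List[Box] = []
        for _ in range(count):
            # Keep only the first four values (x1, y1, w, h); the rest are flags.
            x1, y1, w, h = (int(v) for v in lines[i].split()[:4])
            boxes.append((x1, y1, w, h))
            i += 1
        # Assumption: a 0-face entry may be followed by one placeholder row.
        if count == 0 and i < len(lines) and _is_numeric_row(lines[i]):
            i += 1
        result[path] = boxes
    return result
```

Reading the real file would then be `parse_wider_annotations(open(path).read())`, with `path` pointing at the annotation file configured below.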
Parameters of the data generator in `./data/data_generator.py`:
- `@config_path` - path to the `data/wider_face_split/wider_face_train_bbx_gt.txt` file (defined in `cfg.TRAIN.ANNOTATION_PATH`)
- `@file_path` - path to the folder with images (defined in `cfg.TRAIN.DATA_PATH`)
- `__init__(file_path, config_path, debug=False)` - loops over all images in the txt file (based on `config_path`) and stores them inside the generator to be retrieved by `__getitem__`
- `__len__()` - unsurprisingly returns the length of our data (the number of batches, i.e. `data / batch_size`)
- `__getitem__(idx)` - returns data for the given `idx`, as `Array<imagePath>, Array<h, w, yc, xc, class>`
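To make the `__len__`/`__getitem__` contract concrete, here is a small sketch that needs no TensorFlow. `FaceBatches` and `box_to_target` are hypothetical names; the real generator likely subclasses a Keras sequence class, which is not shown here. Assumptions: one box per image for simplicity, a face class value of `1.0`, coordinates left in pixels (the real generator may rescale them), and rounding the batch count up so the last partial batch is kept.

```python
import math
from typing import List, Sequence, Tuple

Box = Tuple[int, int, int, int]  # (x1, y1, w, h) as read from the annotation file


def box_to_target(box: Box, cls: float = 1.0) -> List[float]:
    """Convert (x1, y1, w, h) into the [h, w, yc, xc, class] order described above."""
    x1, y1, w, h = box
    return [h, w, y1 + h / 2, x1 + w / 2, cls]


class FaceBatches:
    """Hypothetical stand-in for the project's generator (one box per image)."""

    def __init__(self, samples: Sequence[Tuple[str, Box]], batch_size: int):
        self.samples = list(samples)
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Number of batches; round up so a final partial batch is not dropped.
        return math.ceil(len(self.samples) / self.batch_size)

    def __getitem__(self, idx: int):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        start = idx * self.batch_size
        chunk = self.samples[start:start + self.batch_size]
        paths = [path for path, _ in chunk]
        targets = [box_to_target(box) for _, box in chunk]
        return paths, targets
```

With 5 samples and `batch_size=2` this yields 3 batches, the last one holding a single image.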
In `./model/model.py`:
- `create_model(trainable=False)` - creates the model based on its definition; if you want the model to be fully trainable (not only the output layers), set `trainable` to `True`
In `./model/loss.py`:
- `loss(y_true, y_pred)` - returns the value of the loss function for the current prediction (`y_true` is a box from the dataset, `y_pred` is the output of the NN)
- `get_box_highest_percentage(arr)` - helper function for `loss` to get the best box match
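The idea behind picking the "best box match" can be sketched in plain Python. This is not the project's `get_box_highest_percentage`; `best_box` is a hypothetical helper, and it assumes the network output is a `GRID_SIZE x GRID_SIZE` grid whose cells hold `[h, w, yc, xc, score]`, with the score at index 4 (mirroring the `<h, w, yc, xc, class>` order above).

```python
from typing import List, Sequence, Tuple

Cell = Sequence[float]  # assumed layout per grid cell: [h, w, yc, xc, score]


def best_box(grid: Sequence[Sequence[Cell]]) -> Tuple[int, int, List[float]]:
    """Return (row, col, cell) for the grid cell with the highest score."""
    best = None
    for r, row in enumerate(grid):
        for c, cell in enumerate(row):
            if best is None or cell[4] > best[2][4]:
                best = (r, c, list(cell))
    if best is None:
        raise ValueError("empty grid")
    return best
```

Returning the cell position as well as the box makes it easy to compare the predicted cell with the cell that contains the ground-truth box centre.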
In `./model/validation.py`:
- `on_epoch_end(self, epoch, logs)` - calculates `IoU` and `mse` for the validation set
- `get_box_highets_percentage(self, mask)` - helper function, you can ignore it
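For reference, IoU (intersection over union) and MSE for two boxes in the `[h, w, yc, xc, ...]` layout can be computed as below. This is a standard-library sketch of the metrics themselves, not the code in `validation.py`; the helper names are hypothetical and only the first four values of each box are used.

```python
from typing import Sequence, Tuple


def to_corners(box: Sequence[float]) -> Tuple[float, float, float, float]:
    """Convert [h, w, yc, xc, ...] into (y1, x1, y2, x2)."""
    h, w, yc, xc = box[:4]
    return yc - h / 2, xc - w / 2, yc + h / 2, xc + w / 2


def iou(a: Sequence[float], b: Sequence[float]) -> float:
    """Intersection over union of two boxes; 0.0 when they do not overlap."""
    ay1, ax1, ay2, ax2 = to_corners(a)
    by1, bx1, by2, bx2 = to_corners(b)
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter = inter_h * inter_w
    union = a[0] * a[1] + b[0] * b[1] - inter
    return inter / union if union > 0 else 0.0


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error over the four box parameters (h, w, yc, xc)."""
    return sum((x - y) ** 2 for x, y in zip(a[:4], b[:4])) / 4
```

Two 2x2 boxes whose centres are 1 unit apart horizontally overlap in a 2x1 region, giving IoU = 2 / (4 + 4 - 2) = 1/3.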
`./config.py` is just a config, but there are a couple of important things in it:
- `ALPHA` - MobileNet's "alpha" (width multiplier); a higher value means a more complex network (slower, more precise)
- `GRID_SIZE` - output grid size; 7 is a good value for a low `ALPHA`, but you might want to set it higher for larger `ALPHA` values and add an UpSample layer to `model.py`
- `INPUT_SIZE` - should be adjusted based on the initial network used (224 for MobileNetV2; check the expected input size if you change the model)

Under the `TRAIN` prefix there are a couple of training hyperparameters you can adjust.
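Why 7 pairs with 224: MobileNetV2's last feature map is 32 times smaller than its input, so 224 / 32 = 7 cells per side, while `ALPHA` only changes the number of channels, not the spatial size. A quick sketch of that arithmetic, assuming the model predicts from the backbone's final feature map and that each UpSample layer doubles the grid (`grid_size` and `upsample` are hypothetical names):

```python
BACKBONE_STRIDE = 32  # total downsampling of MobileNetV2's final feature map


def grid_size(input_size: int, upsample: int = 1) -> int:
    """Expected GRID_SIZE for a given INPUT_SIZE and total upsampling factor."""
    if input_size % BACKBONE_STRIDE:
        # Other sizes can work, but the grid no longer divides the image evenly.
        raise ValueError(f"INPUT_SIZE {input_size} is not a multiple of {BACKBONE_STRIDE}")
    return input_size // BACKBONE_STRIDE * upsample
```

So `grid_size(224)` gives 7, and adding one 2x UpSample layer gives 14.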
`./detect.py`: you have to train the model first to get at least one `model-0.xx.h5` weights file.

Usage:
```bash
# basic usage
python detect.py --image path_to_my_image.jpg
# use different trained weights and output path
python detect.py --image path_to_my_image.jpg --weights model-0.64.h5 --output output_path.jpg
```
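Before a predicted box can be drawn on the image, its `[h, w, yc, xc]` values have to become pixel corner points. A minimal sketch, assuming the prediction is normalised to `[0, 1]` relative to the image size (check how your model's output is scaled); `to_pixel_rect` is a hypothetical helper, not the project's `draw_boxes`:

```python
from typing import Sequence, Tuple

Point = Tuple[int, int]


def to_pixel_rect(box: Sequence[float], img_w: int, img_h: int) -> Tuple[Point, Point]:
    """Map a normalised [h, w, yc, xc, ...] box to ((x1, y1), (x2, y2)) pixel corners."""
    h, w, yc, xc = box[:4]
    x1 = round((xc - w / 2) * img_w)
    y1 = round((yc - h / 2) * img_h)
    x2 = round((xc + w / 2) * img_w)
    y2 = round((yc + h / 2) * img_h)
    # Clamp to valid pixel indices so boxes touching the border stay drawable.
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(img_w - 1, x2), min(img_h - 1, y2)
    return (x1, y1), (x2, y2)
```

The two returned points match the `pt1`/`pt2` arguments of `cv2.rectangle(img, pt1, pt2, color, thickness)`.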
`./train.py` has no parameters, but you might want to read that file. It runs based on `config.py` and the other files already described. If you want to continue training your model from a specific point, uncomment `IF TRAINABLE` and add a weights file.
After running, the training script generates `./logs/fit/**` files. You can use TensorBoard to visualise training:
```bash
tensorboard --logdir logs/fit
```