- Project Description
- Project Setup
- How To Train Model In this Project
- How To Restore Images In this Project
- How To Check Model Architecture
The GAN (Generative Adversarial Network) is a class of machine learning frameworks designed by Ian Goodfellow and his colleagues in June 2014. It is based on game theory: two neural networks compete against each other.
This project restores images using a GAN model. Here is how it works:
- Model setup:
- Train model:
  - Yield batches of images from `training_data`, whose shape is `(image_count, image_width, image_height, image_channel)`
  - Put a random mask over each image in the batch
  - Customize the loss functions of the discriminator and the generator
  - Apply gradient descent with respect to the variables of the discriminator and the generator
    - using `tensorflow.GradientTape` to implement gradient descent
  - Plot a training progress bar in the terminal
    - using the `rich` Python package to display `epochs`, `completeness`, `generator loss` and `discriminator loss`
  - Save the model structure and parameters when training finishes
- Image Restoration
  - Load the trained model
  - Get any masked image that fits the shape of `training_data`, e.g. `(image_count, image_width, image_height, image_channel)`
  - Restore the image
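The first two training steps (yielding batches and masking each picture) can be sketched with NumPy as follows. This is only an illustration under assumptions: the helper names `iterate_batches` and `apply_random_mask`, the square mask, and the mask size are hypothetical and are not taken from the project's source.

```python
import numpy as np


def iterate_batches(training_data, batch_size):
    """Yield consecutive batches along the first (image_count) axis."""
    for start in range(0, len(training_data), batch_size):
        yield training_data[start:start + batch_size]


def apply_random_mask(batch, mask_size, rng):
    """Zero out one randomly placed square in every image of the batch.

    Returns the masked copy and a 0/1 mask (1 marks the hidden pixels).
    The square mask is an assumption made for this sketch.
    """
    count, width, height, _ = batch.shape
    masked = batch.copy()
    mask = np.zeros(batch.shape, dtype=batch.dtype)
    for i in range(count):
        x = rng.integers(0, width - mask_size + 1)
        y = rng.integers(0, height - mask_size + 1)
        masked[i, x:x + mask_size, y:y + mask_size, :] = 0
        mask[i, x:x + mask_size, y:y + mask_size, :] = 1
    return masked, mask


rng = np.random.default_rng(0)
# Stand-in data shaped (image_count, image_width, image_height, image_channel).
training_data = rng.random((10, 32, 32, 3)).astype("float32")

for batch in iterate_batches(training_data, batch_size=4):
    masked, mask = apply_random_mask(batch, mask_size=8, rng=rng)
```

Returning the mask alongside the masked batch is a design choice of this sketch; it makes it easy to check later which pixels were hidden.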
To avoid TensorFlow version conflicts, the project uses pipenv (a Python virtual environment tool) to install Python packages.
Notice: Before executing the following commands, please refer to TensorFlow Installation Source and modify the TensorFlow version in `Pipfile` and `Pipfile.lock` (or modify `Pipfile` and remove `Pipfile.lock`).
pip install pipenv
pipenv shell
pipenv install
In the model training stage, you can modify the model architecture or the hyperparameters in `src/model/GAN.py`, such as `epochs`, `learning_rate`, `learning_rate_decay`, etc.
python src/train.py
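The training loop described in the project description (custom discriminator and generator losses, with gradient descent through `tensorflow.GradientTape`) could look roughly like the sketch below. It is not the code in `src/model/GAN.py`: the binary cross-entropy losses, the tiny stand-in networks, the Adam optimizers and the names `train_step`, `generator_loss` and `discriminator_loss` are all assumptions made for illustration.

```python
import tensorflow as tf

# Assumed loss: standard binary cross-entropy on raw discriminator logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_logits, fake_logits):
    """Reward the discriminator for scoring real images 1 and restored images 0."""
    real_loss = cross_entropy(tf.ones_like(real_logits), real_logits)
    fake_loss = cross_entropy(tf.zeros_like(fake_logits), fake_logits)
    return real_loss + fake_loss


def generator_loss(fake_logits):
    """Reward the generator when the discriminator scores its output as real."""
    return cross_entropy(tf.ones_like(fake_logits), fake_logits)


def train_step(generator, discriminator, gen_opt, disc_opt, masked, real):
    """Run one gradient descent update for both networks."""
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        restored = generator(masked, training=True)
        real_logits = discriminator(real, training=True)
        fake_logits = discriminator(restored, training=True)
        g_loss = generator_loss(fake_logits)
        d_loss = discriminator_loss(real_logits, fake_logits)
    gen_grads = gen_tape.gradient(g_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(d_loss, discriminator.trainable_variables)
    gen_opt.apply_gradients(zip(gen_grads, generator.trainable_variables))
    disc_opt.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    return g_loss, d_loss


# Tiny stand-in networks, only to make the sketch runnable.
generator = tf.keras.Sequential([
    tf.keras.layers.Conv2D(3, 3, padding="same", activation="sigmoid"),
])
discriminator = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])
gen_opt = tf.keras.optimizers.Adam(1e-4)
disc_opt = tf.keras.optimizers.Adam(1e-4)

real = tf.random.uniform((4, 16, 16, 3))
# Crude stand-in mask: hide roughly a quarter of the pixel positions.
keep = tf.cast(tf.random.uniform((4, 16, 16, 1)) > 0.25, tf.float32)
masked = real * keep
g_loss, d_loss = train_step(generator, discriminator, gen_opt, disc_opt, masked, real)
```

Using two separate tapes keeps the generator and discriminator gradients independent, so each optimizer only updates its own network.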
You can use a model you trained yourself or one of the following models to restore images:
- The example model `generator_example.h5` at `src/model/trained_model/`
- Other trained models on Google Drive
python src/predict.py
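Because the input must fit the shape of `training_data`, a single masked picture needs a leading `image_count` axis before it is handed to the generator. The helper below is a minimal sketch of that check; the name `to_model_batch` and the 32x32x3 example size are hypothetical and not part of `src/predict.py`.

```python
import numpy as np


def to_model_batch(image, expected_shape):
    """Return `image` as a float32 batch shaped (image_count, width, height, channel).

    `expected_shape` is the per-image shape (image_width, image_height, image_channel).
    """
    array = np.asarray(image, dtype="float32")
    if array.ndim == 3:
        # A single image: add the leading image_count axis.
        array = array[np.newaxis, ...]
    if array.ndim != 4 or array.shape[1:] != tuple(expected_shape):
        raise ValueError(
            f"expected (image_count, {', '.join(map(str, expected_shape))}), "
            f"got {array.shape}"
        )
    return array


single = np.zeros((32, 32, 3))
batch = to_model_batch(single, expected_shape=(32, 32, 3))
print(batch.shape)  # (1, 32, 32, 3)
```

The resulting array could then be passed to the loaded generator. Failing early with a `ValueError` gives a clearer message than a shape error raised deep inside the model.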
You can modify `model_path` in `src/watch_model_architecture.py` to inspect any model you want.
python src/watch_model_architecture.py