
🔥 ESRGAN (Enhanced Super-Resolution Generative Adversarial Networks, published in ECCV 2018) implemented in TensorFlow 2.0+. This is an unofficial implementation. 🔥

ESRGAN introduces the Residual-in-Residual Dense Block (RRDB) without batch normalization as the basic network building unit, borrows the idea from relativistic GAN of letting the discriminator predict relative realness, and computes the perceptual loss on features before activation. Benefiting from these improvements, ESRGAN achieves consistently better visual quality with more realistic and natural textures than SRGAN, and it won first place in the PIRM2018-SR Challenge.
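
To make "relative realness" concrete, here is a minimal NumPy sketch of the relativistic average GAN losses as they are written in the ESRGAN paper. It is not taken from this repository's code: the function names are hypothetical, and some implementations additionally average the two terms, which is omitted here.

```python
import numpy as np


def bce_with_logits(logits, targets):
    """Numerically stable sigmoid cross-entropy, averaged over the batch."""
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    loss = (np.maximum(logits, 0.0) - logits * targets
            + np.log1p(np.exp(-np.abs(logits))))
    return float(loss.mean())


def _relative_logits(real_logits, fake_logits):
    """Shift each logit by the batch mean of the opposite class."""
    real_logits = np.asarray(real_logits, dtype=np.float64)
    fake_logits = np.asarray(fake_logits, dtype=np.float64)
    return real_logits - fake_logits.mean(), fake_logits - real_logits.mean()


def ragan_d_loss(real_logits, fake_logits):
    """Discriminator: real images should score above the average fake."""
    real_rel, fake_rel = _relative_logits(real_logits, fake_logits)
    return (bce_with_logits(real_rel, np.ones_like(real_rel))
            + bce_with_logits(fake_rel, np.zeros_like(fake_rel)))


def ragan_g_loss(real_logits, fake_logits):
    """Generator: the same two terms with the targets swapped."""
    real_rel, fake_rel = _relative_logits(real_logits, fake_logits)
    return (bce_with_logits(real_rel, np.zeros_like(real_rel))
            + bce_with_logits(fake_rel, np.ones_like(fake_rel)))
```

Because both terms depend on real and fake logits, the generator receives gradients from real images as well as generated ones, unlike in a standard GAN loss.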

Original Paper:   Arxiv   ECCV2018

Official Implementation:   PyTorch

:: Results from this repository. ::


Contents

📑


Installation

🍕

Create a new Python virtual environment with Anaconda, or just use pip in your existing Python environment, and then clone this repository as follows.

Clone this repo

git clone https://github.com/peteryuX/esrgan-tf2.git
cd esrgan-tf2

Conda

conda env create -f environment.yml
conda activate esrgan-tf2

Pip

pip install -r requirements.txt

Data Preparing

🍺

All datasets used in this repository follow the official implementation as closely as possible. This code focuses on the x4 version.

Training Dataset

Step 1: Download the DIV2K GT images and the corresponding LR images from the download links below.

| Dataset Name | Link |
| --- | --- |
| Ground-Truth | DIV2K_train_HR |
| LRx4 (MATLAB bicubic) | DIV2K_train_LR_bicubic_X4 |

Note: If you want to downsample your training data into LR images yourself, you can use imresize_np(), which is a NumPy implementation, or MATLAB's resize.

Step 2: Extract them into ./data/DIV2K/. The directory structure should look like the one below.

./data/DIV2K/
    -> DIV2K_train_HR/
        -> 0001.png
        -> 0002.png
        -> ...
    -> DIV2K_train_LR_bicubic/
        -> X4/
            -> 0001x4.png
            -> 0002x4.png
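
Before renaming anything, it can help to confirm that every HR image has its LR counterpart. The helper below is a hypothetical sketch, not part of this repository; it only assumes the `NNNN.png` / `NNNNx4.png` naming shown in the tree above.

```python
from pathlib import Path


def missing_lr_files(hr_dir, lr_dir, scale=4):
    """Return the LR file names expected for each HR image but not found."""
    hr_dir, lr_dir = Path(hr_dir), Path(lr_dir)
    missing = []
    for hr_path in sorted(hr_dir.glob("*.png")):
        # 0001.png in the HR folder pairs with 0001x4.png in the LR folder.
        lr_name = f"{hr_path.stem}x{scale}.png"
        if not (lr_dir / lr_name).is_file():
            missing.append(lr_name)
    return missing
```

Call it with your HR folder and the `X4` LR folder; an empty list means the two folders are paired one to one.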

Step 3: Rename the LR images and crop both sets into sub-images with the scripts below. Modify these scripts if you need other settings.

# rename image files in the LR folder `DIV2K_train_LR_bicubic/*`.
python data/rename.py

# extract sub-images from HR folder and LR folder.
python data/extract_subimages.py
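
For readers who want to adapt the cropping, the following is a rough sketch of a sliding-window sub-image grid. It illustrates the general technique only: the function names are hypothetical, and the crop size, step and threshold defaults are illustrative assumptions, not values confirmed from data/extract_subimages.py.

```python
def crop_starts(length, crop_size, step, thresh):
    """Start offsets of sliding windows along one image axis.

    A final window flush with the border is added when the regular grid
    leaves more than `thresh` pixels uncovered.
    """
    if length < crop_size:
        return []
    starts = list(range(0, length - crop_size + 1, step))
    if length - (starts[-1] + crop_size) > thresh:
        starts.append(length - crop_size)
    return starts


def sub_image_boxes(height, width, crop_size=480, step=240, thresh=48):
    """(top, left, bottom, right) boxes for every sub-image of one image."""
    return [(y, x, y + crop_size, x + crop_size)
            for y in crop_starts(height, crop_size, step, thresh)
            for x in crop_starts(width, crop_size, step, thresh)]
```

For the LR images the same grid would be used with every size divided by the scale factor, so that each LR sub-image covers the same region as its HR sub-image.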

Step 4: Convert the sub-images to a tfrecord file with the script below.

# Binary Image (recommended): slower to convert, but faster to load when training.
python data/convert_train_tfrecord.py --output_path="./data/DIV2K800_sub_bin.tfrecord" --is_binary=True
# or
# Online Image Loading: faster to convert, but slower to load when training.
python data/convert_train_tfrecord.py --output_path="./data/DIV2K800_sub.tfrecord" --is_binary=False

Note:

  • You can run python ./dataset_checker.py to check whether the dataloader works.
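
One property worth checking in any SR dataloader is that the LR and HR patches stay aligned through cropping and augmentation. The sketch below shows the idea in NumPy. It is not this repository's loader: the function names are hypothetical, and the horizontal flip and 90-degree rotations are assumptions about what flip and rotation augmentation usually mean.

```python
import numpy as np


def paired_random_crop(lr, hr, input_size=32, scale=4, rng=None):
    """Crop an LR patch and the HR patch covering the same region."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = lr.shape[:2]
    top = int(rng.integers(0, h - input_size + 1))
    left = int(rng.integers(0, w - input_size + 1))
    gt_size = input_size * scale
    lr_patch = lr[top:top + input_size, left:left + input_size]
    hr_patch = hr[top * scale:top * scale + gt_size,
                  left * scale:left * scale + gt_size]
    return lr_patch, hr_patch


def paired_flip_rot(lr, hr, rng=None):
    """Apply the same random flip and rotation to both patches."""
    rng = np.random.default_rng() if rng is None else rng
    if rng.random() < 0.5:
        lr, hr = lr[:, ::-1], hr[:, ::-1]  # horizontal flip
    k = int(rng.integers(0, 4))            # number of 90-degree turns
    return np.rot90(lr, k), np.rot90(hr, k)
```

With a 32-pixel LR patch and a scale of 4, the HR patch is 128 pixels wide, and the random draws are shared between the two patches so they cannot drift apart.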

Testing Dataset

Step 1: Download the common image SR datasets from the download links below. Only Set5 and Set14 are needed for the default settings in ./configs/*.yaml.

| Dataset Name | Short Description | Link |
| --- | --- | --- |
| Set5 | Set5 test dataset | Google Drive |
| Set14 | Set14 test dataset | Google Drive |
| BSDS100 | A subset (test) of BSD500 for testing | Google Drive |
| Urban100 | 100 building images for testing (regular structures) | Google Drive |
| Manga109 | 109 images of Japanese manga for testing | Google Drive |
| Historical | 10 gray LR images without the ground-truth | Google Drive |

Step 2: Extract them into ./data/. The directory structure should look like the one below.

./data/
    -> Set5/
        -> baby.png
        -> bird.png
        -> ...
    -> Set14/
        -> ...

Training and Testing

🍭

Config File

You can modify the dataset paths or other model settings in ./configs/*.yaml for training and testing. A config file looks like the one below.

# general setting
batch_size: 16
input_size: 32
gt_size: 128
ch_size: 3
scale: 4
sub_name: 'esrgan'
pretrain_name: 'psnr_pretrain'

# generator setting
network_G:
    nf: 64
    nb: 23
# discriminator setting
network_D:
    nf: 64

# dataset setting
train_dataset:
    path: './data/DIV2K800_sub_bin.tfrecord'
    num_samples: 32208
    using_bin: True
    using_flip: True
    using_rot: True
test_dataset:
    set5_path: './data/Set5'
    set14_path: './data/Set14'

# training setting
niter: 400000

lr_G: !!float 1e-4
lr_D: !!float 1e-4
lr_steps: [50000, 100000, 200000, 300000]
lr_rate: 0.5

adam_beta1_G: 0.9
adam_beta2_G: 0.99
adam_beta1_D: 0.9
adam_beta2_D: 0.99

w_pixel: !!float 1e-2
pixel_criterion: l1

w_feature: 1.0
feature_criterion: l1

w_gan: !!float 5e-3
gan_type: ragan  # gan | ragan

save_steps: 5000

Note:

  • sub_name is the name of the output directory used in the checkpoints and logs folders. (Make sure it is unique among your models.)
  • using_bin selects the type of training data and must match the data type you created in Data Preparing.
  • w_pixel/w_feature/w_gan are the weights of the pixel/feature/GAN losses in the combined loss.
  • save_steps is the number of steps between saved checkpoint files.

Training

Pretrain PSNR

Pretrain the PSNR RRDB model yourself, or download it from BenchmarkModels.

python train_psnr.py --cfg_path="./configs/psnr.yaml" --gpu=0

ESRGAN

Train the ESRGAN model starting from the pretrained PSNR model.

python train_esrgan.py --cfg_path="./configs/esrgan.yaml" --gpu=0

Note:

  • Make sure you have the pretrained PSNR model before training the ESRGAN model. (The pretrained checkpoint should be located in ./checkpoints for restoring.)
  • --gpu selects the ID of the available GPU device to use, through the CUDA_VISIBLE_DEVICES environment variable.
  • You can visualize the learning rate schedule by running "python ./modules/lr_scheduler.py".
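
The lr_steps and lr_rate settings in the config suggest a step-wise decay. The sketch below shows one plausible reading, where the rate is multiplied by lr_rate each time a step in lr_steps is reached. It is an assumption, not a copy of ./modules/lr_scheduler.py; in particular, whether the decay applies at exactly the listed step or one step later may differ.

```python
import bisect


def multistep_lr(step, base_lr=1e-4,
                 lr_steps=(50000, 100000, 200000, 300000), lr_rate=0.5):
    """Learning rate after decaying once per milestone already reached."""
    n_decays = bisect.bisect_right(lr_steps, step)
    return base_lr * lr_rate ** n_decays
```

Under this reading, the default ESRGAN schedule starts at 1e-4 and halves four times, ending at 6.25e-6 for the last 100000 steps.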

Testing

You can download my trained models from Models for testing, without training them yourself. Then evaluate the models with the corresponding cfg file on the testing dataset. The visualization results will be saved into ./results/.

# Test ESRGAN model
python test.py --cfg_path="./configs/esrgan.yaml"
# or
# PSNR pretrain model
python test.py --cfg_path="./configs/psnr.yaml"

SR Input Image

You can upsample your own image with the SR model. For example, upsample the image ./data/baboon.png as follows.

python test.py --cfg_path="./configs/esrgan.yaml" --img_path="./data/baboon.png"
# or
# PSNR pretrain model
python test.py --cfg_path="./configs/psnr.yaml" --img_path="./data/baboon.png"

Network Interpolation

Produce comparison results between network interpolation and image interpolation, as in the original paper.

python net_interp.py --cfg_path1="./configs/psnr.yaml" --cfg_path2="./configs/esrgan.yaml" --img_path="./data/PIPRM_3_crop.png" --save_image=True --save_ckpt=True

Note:

  • --save_image saves the comparison results into ./results_interp.
  • --save_ckpt saves all the interpolated ckpt files into ./results_interp.
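
Network interpolation, as described in the ESRGAN paper, blends the weights of the PSNR-oriented model and the GAN-trained model parameter by parameter. The sketch below shows that blend on plain dicts of NumPy arrays. The function name and the dict layout are hypothetical and do not reflect how net_interp.py stores checkpoints.

```python
import numpy as np


def interpolate_weights(psnr_weights, gan_weights, alpha):
    """Blend two weight dicts: alpha=0 gives PSNR, alpha=1 gives ESRGAN."""
    if psnr_weights.keys() != gan_weights.keys():
        raise ValueError("both models must have the same parameter names")
    return {
        name: (1.0 - alpha) * np.asarray(psnr_weights[name])
        + alpha * np.asarray(gan_weights[name])
        for name in psnr_weights
    }
```

Image interpolation applies the same blend to the two output images instead of the weights, which is what the comparison contrasts.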

Benchmark and Visualization

Verification results (PSNR/SSIM) and visualization results.

Set5

| Image Name | Bicubic | PSNR (pretrain) | ESRGAN | Ground Truth |
| --- | --- | --- | --- | --- |
| baby | 31.96 / 0.85 | 33.86 / 0.89 | 31.36 / 0.83 | - |
| bird | 30.27 / 0.87 | 35.00 / 0.94 | 32.22 / 0.90 | - |
| butterfly | 22.25 / 0.72 | 28.56 / 0.92 | 26.66 / 0.88 | - |
| head | 32.01 / 0.76 | 33.18 / 0.80 | 30.19 / 0.70 | - |
| woman | 26.44 / 0.83 | 30.42 / 0.92 | 28.50 / 0.88 | - |

Set14 (Partial)

| Image Name | Bicubic | PSNR (pretrain) | ESRGAN | Ground Truth |
| --- | --- | --- | --- | --- |
| baboon | 22.06 / 0.45 | 22.77 / 0.54 | 20.73 / 0.44 | - |
| comic | 21.69 / 0.59 | 23.46 / 0.74 | 21.08 / 0.64 | - |
| lenna | 29.67 / 0.80 | 32.06 / 0.85 | 28.96 / 0.80 | - |
| monarch | 27.60 / 0.88 | 33.27 / 0.94 | 31.49 / 0.92 | - |
| zebra | 24.15 / 0.68 | 27.29 / 0.78 | 24.86 / 0.67 | - |

Note:

  • The baseline Bicubic resizing method can be found at imresize_np().
  • All the PSNR and SSIM results are calculated on the Y channel of YCbCr.
  • All models were trained on DIV2K.
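
For reference, PSNR on the Y channel can be sketched as follows. This assumes the common MATLAB-style BT.601 conversion, where Y lies in [16, 235] for 8-bit RGB input. The function names are hypothetical, and details such as rounding or border cropping, which this repository's evaluation code may apply, are left out.

```python
import numpy as np


def rgb_to_y(img):
    """Luma (Y of YCbCr, BT.601) of an RGB image with values in [0, 255]."""
    img = np.asarray(img, dtype=np.float64) / 255.0
    return 16.0 + img @ np.array([65.481, 128.553, 24.966])


def psnr_y(img1, img2):
    """PSNR in dB between the Y channels of two same-sized RGB images."""
    diff = rgb_to_y(img1) - rgb_to_y(img2)
    mse = float(np.mean(diff ** 2))
    if mse == 0.0:
        return float("inf")
    return float(20.0 * np.log10(255.0 / np.sqrt(mse)))
```

Scores computed on Y are typically not comparable with scores computed on full RGB, so the same convention should be used when comparing against the tables above.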

Network Interpolation (on ./data/PIPRM_3_crop.png)

weight interpolation

image interpolation

(ESRGAN <-> PSNR, alpha=[1., 0.8, 0.6, 0.4, 0.2, 0.])


Models

🍩

| Model Name | Download Link |
| --- | --- |
| PSNR | GoogleDrive |
| ESRGAN | GoogleDrive |
| PSNR (inference) | GoogleDrive |
| ESRGAN (inference) | GoogleDrive |

Note:

  • After downloading these models, extract them into ./checkpoints for restoring.
  • The inference version was saved without any training operators, so it is smaller than the original version. However, if you want to finetune, the original version is more suitable.
  • All training settings of the models can be found in the corresponding ./configs/*.yaml files.
  • Based on the property of the training dataset, all the pre-trained models can only be used for non-commercial applications.

References

🍔

Thanks to these source codes for providing the knowledge I needed to complete this repository.