diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2757c63 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2021] [MANet Authors] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index fdf87bd..8763a3d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,113 @@ -# Mutual Affine Network for Spatially Variant Kernel Estimation in Blind Image Super-Resolution (MANet, ICCV2021) -Official PyTorch code for Mutual Affine Network for Spatially Variant Kernel Estimation in Blind Image Super-Resolution (MANet, ICCV2021) -# Stay tuned! The code is coming before 18th August. +# Mutual Affine Network for Spatially Variant Kernel Estimation in Blind Image Super-Resolution (MANet) + +This repository is the official PyTorch implementation of Mutual Affine Network for ***Spatially Variant*** Kernel Estimation in Blind Image Super-Resolution +([arxiv](https://arxiv.org/pdf/2108.05302.pdf), [supplementary](https://github.com/JingyunLiang/MANet/releases/tag/v0.0)). + + :rocket: :rocket: :rocket: **News**: + - Aug. 17, 2021: See our recent work for flow-based image SR: [Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling (HCFlow), ICCV2021](https://github.com/JingyunLiang/HCFlow) + - Aug. 17, 2021: See our recent work for real-world image SR: [Designing a Practical Degradation Model for Deep Blind Image Super-Resolution (BSRGAN), ICCV2021](https://github.com/cszn/BSRGAN) + - Aug. 17, 2021: See our previous flow-based work: *[Flow-based Kernel Prior with Application to Blind Super-Resolution (FKP), CVPR2021](https://github.com/JingyunLiang/FKP).* + + --- + +> Existing blind image super-resolution (SR) methods mostly assume blur kernels are spatially invariant across the whole image. However, such an assumption is rarely applicable for real images whose blur kernels are usually spatially variant due to factors such as object motion and out-of-focus. Hence, existing blind SR methods would inevitably give rise to poor performance in real applications. To address this issue, this paper proposes a mutual affine network (MANet) for spatially variant kernel estimation. Specifically, MANet has two distinctive features. First, it has a moderate receptive field so as to keep the locality of degradation. Second, it involves a new mutual affine convolution (MAConv) layer that enhances feature expressiveness without increasing receptive field, model size and computation burden. This is made possible through exploiting channel interdependence, which applies each channel split with an affine transformation module whose input are the rest channel splits. Extensive experiments on synthetic and real images show that the proposed MANet not only performs favorably for both spatially variant and invariant kernel estimation, but also leads to state-of-the-art blind SR performance when combined with non-blind SR methods. +>
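For intuition, the mutual affine convolution (MAConv) idea from the abstract can be sketched as below. This is a minimal, illustrative two-split version, not the official layer from this repository: each channel split is modulated by affine parameters (scale and shift) predicted from the other split, and only the per-split 3×3 convolutions enlarge the receptive field. The class name, channel sizes and branch layout here are assumptions made for illustration only.

```python
import torch
import torch.nn as nn


class MAConvSketch(nn.Module):
    """Illustrative 2-split mutual affine convolution (not the official layer).

    Each channel split is modulated by an affine transform (scale, shift)
    predicted from the *other* split, then passed through its own conv.
    The 1x1 branches add cross-split interaction without growing the
    receptive field; only the final 3x3 convs enlarge it.
    """

    def __init__(self, in_ch=64, out_ch=64, reduction=2):
        super().__init__()
        half = in_ch // 2
        hidden = half // reduction
        # affine predictors: input = the other split, output = (scale, shift)
        self.affine1 = nn.Sequential(nn.Conv2d(half, hidden, 1), nn.ReLU(inplace=True),
                                     nn.Conv2d(hidden, half * 2, 1))
        self.affine2 = nn.Sequential(nn.Conv2d(half, hidden, 1), nn.ReLU(inplace=True),
                                     nn.Conv2d(hidden, half * 2, 1))
        # per-split 3x3 convs applied after modulation
        self.conv1 = nn.Conv2d(half, out_ch // 2, 3, padding=1)
        self.conv2 = nn.Conv2d(half, out_ch - out_ch // 2, 3, padding=1)

    def forward(self, x):
        x1, x2 = torch.chunk(x, 2, dim=1)
        s1, b1 = torch.chunk(self.affine1(x2), 2, dim=1)  # params for split 1 come from split 2
        s2, b2 = torch.chunk(self.affine2(x1), 2, dim=1)  # params for split 2 come from split 1
        y1 = self.conv1(x1 * s1 + b1)
        y2 = self.conv2(x2 * s2 + b2)
        return torch.cat([y1, y2], dim=1)


x = torch.randn(1, 64, 48, 48)
print(MAConvSketch()(x).shape)  # torch.Size([1, 64, 48, 48])
```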
+ + + +## Requirements +- Python 3.6, PyTorch >= 1.6 +- Requirements: opencv-python +- Platforms: Ubuntu 16.04, CUDA 10.0 & cuDNN 7.5 + +Note: this repository is based on [BasicSR](https://github.com/xinntao/BasicSR#memo-codebase-designs-and-conventions). Please refer to their repository for a better understanding of the code framework. + + +## Quick Run +Download `stage3_MANet+RRDB_x4.pth` from the [release](https://github.com/JingyunLiang/MANet/releases/tag/v0.0) and put it in `./pretrained_models`. Then, run this command: +```bash +cd codes +python test.py --opt options/test/test_stage3.yml +``` +--- + +## Data Preparation +To prepare data, put the training and testing sets in `./datasets`, e.g. `./datasets/DIV2K/HR/0801.png`. Commonly used datasets can be downloaded [here](https://github.com/xinntao/BasicSR/blob/master/docs/DatasetPreparation.md#common-image-sr-datasets). + + +## Training + +Step 1: to train MANet, run this command: + +```bash +python train.py --opt options/train/train_stage1.yml +``` + +Step 2: to train the non-blind RRDB, run this command: + +```bash +python train.py --opt options/train/train_stage2.yml +``` + +Step 3: to fine-tune RRDB with MANet, run this command: + +```bash +python train.py --opt options/train/train_stage3.yml +``` + +All trained models can be downloaded from the [release](https://github.com/JingyunLiang/MANet/releases/tag/v0.0). For testing, downloading the stage 3 models is enough. + + +## Testing + +To test MANet (stage 1, kernel estimation only), run this command: + +```bash +python test.py --opt options/test/test_stage1.yml +``` +To test RRDB-SFT (stage 2, non-blind SR with ground-truth kernel), run this command: + +```bash +python test.py --opt options/test/test_stage2.yml +``` +To test MANet+RRDB (stage 3, blind SR), run this command: + +```bash +python test.py --opt options/test/test_stage3.yml +``` +Note: the above commands generate LR images on the fly. To generate the testing sets used in the paper, run this command: +```bash +python prepare_testset.py --opt options/test/prepare_testset.yml +``` + +## Interactive Exploration of Kernels +To explore spatially variant kernels on an image, use `--save_kernel` and run this command to save the estimated kernels: + +```bash +python test.py --opt options/test/test_stage1.yml --save_kernel +``` +Then, run this command to create an interactive window: +```bash +python interactive_explore.py --path ../results/001_MANet_aniso_x4_test_stage1/toy_dataset1/npz/toy1.npz +``` +The saved `.npz` file can also be inspected directly in Python; see the snippet at the end of this README. + +## Results +We conducted experiments on both spatially variant and invariant blind SR. Please refer to the [paper](https://arxiv.org/abs/2108.05302) and [supp](https://github.com/JingyunLiang/MANet/releases/tag/v0.0) for results. + +## Citation + @inproceedings{liang21manet, + title={Mutual Affine Network for Spatially Variant Kernel Estimation in Blind Image Super-Resolution}, + author={Liang, Jingyun and Sun, Guolei and Zhang, Kai and Van Gool, Luc and Timofte, Radu}, + booktitle={IEEE International Conference on Computer Vision}, + year={2021} + } + +## License & Acknowledgement + +This project is released under the Apache 2.0 license. The code is based on [BasicSR](https://github.com/xinntao/BasicSR), [MMSR](https://github.com/open-mmlab/mmediting), [IKC](https://github.com/yuanjunchai/IKC) and [KAIR](https://github.com/cszn/KAIR). Please also follow their licenses. Thanks for their great work. 
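As a non-interactive companion to the *Interactive Exploration of Kernels* section above, the saved `.npz` file can be loaded directly. The sketch below assumes the same keys (`sr_img`, `est_ker_sv`, `gt_ker`) and the same flattened per-pixel kernel indexing that `codes/interactive_explore.py` reads; the path and the inspected pixel are placeholders.

```python
import numpy as np

# same example path as in the Interactive Exploration section above
npz_path = '../results/001_MANet_aniso_x4_test_stage1/toy_dataset1/npz/toy1.npz'
data = np.load(npz_path)

sr_img = data['sr_img']          # SR image, HWC (stored in BGR order)
est_ker_sv = data['est_ker_sv']  # spatially variant kernels, flattened to (H*W, k, k)
gt_ker = data['gt_ker']          # ground-truth kernel (all zeros for real images)

h, w = sr_img.shape[:2]
y, x = h // 2, w // 2            # pick any pixel to inspect
kernel = est_ker_sv[x + y * w]   # same flattened indexing as interactive_explore.py
print(kernel.shape, float(kernel.sum()))
```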
+ + + diff --git a/codes/data/GTLQ_dataset.py b/codes/data/GTLQ_dataset.py new file mode 100644 index 0000000..2154597 --- /dev/null +++ b/codes/data/GTLQ_dataset.py @@ -0,0 +1,136 @@ +import random +import numpy as np +import cv2 +import lmdb +import torch +import torch.utils.data as data +import data.util as util +import sys +import os + +try: + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from data.util import imresize_np + from utils import util as utils +except ImportError: + pass + + +class GTLQDataset(data.Dataset): + ''' + Load HR-LR image pairs. + ''' + + def __init__(self, opt): + super(GTLQDataset, self).__init__() + self.opt = opt + self.LR_paths, self.GT_paths = None, None + self.LR_env, self.GT_env = None, None # environment for lmdb + self.LR_size, self.GT_size = opt['LR_size'], opt['GT_size'] + + # read image list from lmdb or image files + if opt['data_type'] == 'lmdb': + self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) + self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) + elif opt['data_type'] == 'img': + self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list + self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list + else: + print('Error: data_type is not matched in Dataset') + assert self.GT_paths, 'Error: GT paths are empty.' + if self.LR_paths and self.GT_paths: + assert len(self.LR_paths) == len( + self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( + len(self.LR_paths), len(self.GT_paths)) + self.random_scale_list = [1] + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, + meminit=False) + if self.opt['dataroot_LQ'] is not None: + self.LR_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) + else: + self.LR_env = 'No lmdb input for LR' + + def __getitem__(self, index): + if self.opt['data_type'] == 'lmdb': + if (self.GT_env is None) or (self.LR_env is None): + self._init_lmdb() + + GT_path, LR_path = None, None + scale = self.opt['scale'] + GT_size = self.opt['GT_size'] + + # get GT image + GT_path = self.GT_paths[index] + LR_path = self.LR_paths[index] + if self.opt['data_type'] == 'lmdb': + resolution = [int(s) for s in self.GT_sizes[index].split('_')] + else: + resolution = None + img_GT = util.read_img(self.GT_env, GT_path, resolution) # return: Numpy float32, HWC, BGR, [0,1] + + # modcrop the image + img_GT = util.modcrop(img_GT, scale) + + # get LR image + if self.LR_paths: # LR exist + img_LR = util.read_img(self.LR_env, LR_path, resolution) + + else: # down-sampling on-the-fly + # randomly scale during training + if self.opt['phase'] == 'train': + random_scale = random.choice(self.random_scale_list) + if random_scale != 1: + H_s, W_s, _ = img_GT.shape + H_s = _mod(H_s, random_scale, scale, GT_size) + W_s = _mod(W_s, random_scale, scale, GT_size) + img_GT = cv2.resize(np.copy(img_GT), (W_s, H_s), interpolation=cv2.INTER_LINEAR) + + # force to 3 channels + if img_GT.ndim == 2: + img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) + + if self.opt['phase'] == 'train': + H, W, C = img_GT.shape + + # randomly crop on HR, more positions than first crop on LR and HR simultaneously + rnd_h_GT = random.randint(0, max(0, H - GT_size)) + rnd_w_GT = random.randint(0, max(0, W - GT_size)) + img_GT = 
img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] + + # augmentation - flip, rotate + img_GT = util.augment(img_GT, self.opt['use_flip'], + self.opt['use_rot'], self.opt['mode']) + + # change color space if necessary, deal with gray image + if self.opt['color']: + img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] + img_LR = util.channel_convert(img_LR.shape[2], self.opt['color'], [img_LR])[0] + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_GT.shape[2] == 3: + img_GT = img_GT[:, :, [2, 1, 0]] + if img_LR.shape[2] == 3: + img_LR = img_LR[:, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() + + if LR_path is None: + LR_path = GT_path + + # # don't need LR because it's generated from HR batches. + # img_LR = torch.ones(1,1,1) + + return {'LQ': img_LR, 'GT': img_GT, 'LQ_path': LR_path, 'GT_path': GT_path} + + def __len__(self): + return len(self.GT_paths) + + +def _mod(n, random_scale, scale, thres): + rlt = int(n * random_scale) + rlt = (rlt // scale) * scale + return thres if rlt < thres else rlt diff --git a/codes/data/GT_dataset.py b/codes/data/GT_dataset.py new file mode 100644 index 0000000..7bd5294 --- /dev/null +++ b/codes/data/GT_dataset.py @@ -0,0 +1,131 @@ +import random +import numpy as np +import cv2 +import lmdb +import torch +import torch.utils.data as data +import data.util as util +import sys +import os + +try: + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from data.util import imresize_np + from utils import util as utils +except ImportError: + pass + + +class GTDataset(data.Dataset): + ''' + Load GT images only. 30s faster than LQGTKer (90s for 200iter). + ''' + + def __init__(self, opt): + super(GTDataset, self).__init__() + self.opt = opt + self.LR_paths, self.GT_paths = None, None + self.LR_env, self.GT_env = None, None # environment for lmdb + self.LR_size, self.GT_size = opt['LR_size'], opt['GT_size'] + + # read image list from lmdb or image files + if opt['data_type'] == 'lmdb': + self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) + self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) + elif opt['data_type'] == 'img': + self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list + self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list + else: + print('Error: data_type is not matched in Dataset') + assert self.GT_paths, 'Error: GT paths are empty.' 
+ if self.LR_paths and self.GT_paths: + assert len(self.LR_paths) == len( + self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( + len(self.LR_paths), len(self.GT_paths)) + self.random_scale_list = [1] + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, + meminit=False) + if self.opt['dataroot_LQ'] is not None: + self.LR_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) + else: + self.LR_env = 'No lmdb input for LR' + + def __getitem__(self, index): + if self.opt['data_type'] == 'lmdb': + if (self.GT_env is None) or (self.LR_env is None): + self._init_lmdb() + + GT_path, LR_path = None, None + scale = self.opt['scale'] + GT_size = self.opt['GT_size'] + + # get GT image + GT_path = self.GT_paths[index] + if self.opt['data_type'] == 'lmdb': + resolution = [int(s) for s in self.GT_sizes[index].split('_')] + else: + resolution = None + img_GT = util.read_img(self.GT_env, GT_path, resolution) # return: Numpy float32, HWC, BGR, [0,1] + + # modcrop in the validation / test phase + img_GT = util.modcrop(img_GT, scale) + + # get LR image + if self.LR_paths: # LR exist + raise ValueError('GTker_dataset.py doesn Not allow LR input.') + + else: # down-sampling on-the-fly + # randomly scale during training + if self.opt['phase'] == 'train': + random_scale = random.choice(self.random_scale_list) + if random_scale != 1: + H_s, W_s, _ = img_GT.shape + H_s = _mod(H_s, random_scale, scale, GT_size) + W_s = _mod(W_s, random_scale, scale, GT_size) + img_GT = cv2.resize(np.copy(img_GT), (W_s, H_s), interpolation=cv2.INTER_LINEAR) + + # force to 3 channels + if img_GT.ndim == 2: + img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) + + if self.opt['phase'] == 'train': + H, W, C = img_GT.shape + + # randomly crop on HR, more positions than first crop on LR and HR simultaneously + rnd_h_GT = random.randint(0, max(0, H - GT_size)) + rnd_w_GT = random.randint(0, max(0, W - GT_size)) + img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] + + # augmentation - flip, rotate + img_GT = util.augment(img_GT, self.opt['use_flip'], + self.opt['use_rot'], self.opt['mode']) + + # change color space if necessary, deal with gray image + if self.opt['color']: + img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_GT.shape[2] == 3: + img_GT = img_GT[:, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + + if LR_path is None: + LR_path = GT_path + + # don't need LR because it's generated from HR batches. 
+ img_LR = torch.ones(1, 1, 1) + + return {'LQ': img_LR, 'GT': img_GT, 'LQ_path': LR_path, 'GT_path': GT_path} + + def __len__(self): + return len(self.GT_paths) + + +def _mod(n, random_scale, scale, thres): + rlt = int(n * random_scale) + rlt = (rlt // scale) * scale + return thres if rlt < thres else rlt diff --git a/codes/data/LQ_dataset.py b/codes/data/LQ_dataset.py new file mode 100644 index 0000000..d11ae9c --- /dev/null +++ b/codes/data/LQ_dataset.py @@ -0,0 +1,106 @@ +import random +import numpy as np +import cv2 +import lmdb +import torch +import torch.nn.functional as F +import torch.utils.data as data +import data.util as util +import sys +import os + +try: + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from data.util import imresize_np + from utils import util as utils +except ImportError: + pass + + +class LQDataset(data.Dataset): + ''' + Load LR images only, e.g. real-world images + ''' + + def __init__(self, opt): + super(LQDataset, self).__init__() + self.opt = opt + self.LR_paths, self.GT_paths = None, None + self.LR_env, self.GT_env = None, None # environment for lmdb + self.LR_size, self.GT_size = opt['LR_size'], opt['GT_size'] + + # read image list from lmdb or image files + if opt['data_type'] == 'lmdb': + self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) + self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) + elif opt['data_type'] == 'img': + self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list + self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list + else: + print('Error: data_type is not matched in Dataset') + assert self.LR_paths, 'Error: LQ paths are empty.' + if self.LR_paths and self.GT_paths: + assert len(self.LR_paths) == len( + self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( + len(self.LR_paths), len(self.GT_paths)) + self.random_scale_list = [1] + + def __getitem__(self, index): + + GT_path, LQ_path = None, None + + # get GT image + LQ_path = self.LR_paths[index] + img_LQ = util.read_img(None, LQ_path, None) # return: Numpy float32, HWC, BGR, [0,1] + + if self.GT_paths: # LR exist + raise ValueError('LQ_dataset.py doesn Not allow HR input.') + + else: + # force to 3 channels + if img_LQ.ndim == 2: + img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_GRAY2BGR) + + # change color space if necessary, deal with gray image + if self.opt['color']: + img_LQ = util.channel_convert(img_LQ.shape[2], self.opt['color'], [img_LQ])[0] + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_LQ.shape[2] == 3: + img_LQ = img_LQ[:, :, [2, 1, 0]] + img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + + if GT_path is None: + GT_path = LQ_path + + # don't need LR because it's generated from HR batches. 
+ img_GT = torch.ones(1, 1, 1) + + # deal with the image margins for real-world images + img_LQ = img_LQ.unsqueeze(0) + x_gt = F.interpolate(img_LQ, scale_factor=self.opt['scale'], mode='nearest') + if self.opt['scale'] == 4: + real_crop = 3 + elif self.opt['scale'] == 2: + real_crop = 6 + elif self.opt['scale'] == 1: + real_crop = 11 + assert real_crop * self.opt['scale'] * 2 > self.opt['kernel_size'] + x_gt = F.pad(x_gt, pad=( + real_crop * self.opt['scale'], real_crop * self.opt['scale'], real_crop * self.opt['scale'], + real_crop * self.opt['scale']), mode='replicate') # 'constant', 'reflect', 'replicate' or 'circular + + kernel_gt, sigma_gt = utils.stable_batch_kernel(1, l=self.opt['kernel_size'], sig=10, sig1=0, sig2=0, + theta=0, rate_iso=1, scale=self.opt['scale'], + tensor=True) # generate kernel [BHW], y [BCHW] + + blur_layer = utils.BatchBlur(l=self.opt['kernel_size'], padmode='zero') + sample_layer = utils.BatchSubsample(scale=self.opt['scale']) + y_blurred = sample_layer(blur_layer(x_gt, kernel_gt)) + y_blurred[:, :, real_crop:-real_crop, real_crop:-real_crop] = img_LQ + img_LQ = y_blurred.squeeze(0) + + return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} + + def __len__(self): + return len(self.LR_paths) diff --git a/codes/data/__init__.py b/codes/data/__init__.py new file mode 100644 index 0000000..1347999 --- /dev/null +++ b/codes/data/__init__.py @@ -0,0 +1,45 @@ +'''create dataset and dataloader''' +import logging +import torch +import torch.utils.data + + +def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): + phase = dataset_opt['phase'] + if phase == 'train': + if opt['dist']: + world_size = torch.distributed.get_world_size() + num_workers = dataset_opt['n_workers'] + assert dataset_opt['batch_size'] % world_size == 0 + batch_size = dataset_opt['batch_size'] // world_size + shuffle = False + else: + num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) + batch_size = dataset_opt['batch_size'] + shuffle = True + return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, + num_workers=num_workers, sampler=sampler, drop_last=True, + pin_memory=True) + else: + return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, + pin_memory=True) + + +def create_dataset(dataset_opt): + mode = dataset_opt['mode'] + if mode == 'GT': # load HR image and generate LR on-the-fly + from data.GT_dataset import GTDataset as D + dataset = D(dataset_opt) + elif mode == 'GTLQ': # load generated HR-LR image pairs + from data.GTLQ_dataset import GTLQDataset as D + dataset = D(dataset_opt) + elif mode == 'LQ': # load LR image for testing + from data.LQ_dataset import LQDataset as D + dataset = D(dataset_opt) + else: + raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) + + logger = logging.getLogger('base') + logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, + dataset_opt['name'])) + return dataset diff --git a/codes/data/data_sampler.py b/codes/data/data_sampler.py new file mode 100644 index 0000000..f55d682 --- /dev/null +++ b/codes/data/data_sampler.py @@ -0,0 +1,65 @@ +""" +Modified from torch.utils.data.distributed.DistributedSampler +Support enlarging the dataset for *iter-oriented* training, for saving time when restart the +dataloader after each epoch +""" +import math +import torch +from torch.utils.data.sampler import Sampler +import torch.distributed as dist + + +class DistIterSampler(Sampler): + """Sampler that restricts data 
loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() #Returns a random permutation of integers from 0 to n - 1 + + dsize = len(self.dataset) + indices = [v % dsize for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/codes/data/util.py b/codes/data/util.py new file mode 100644 index 0000000..6aaf0b3 --- /dev/null +++ b/codes/data/util.py @@ -0,0 +1,497 @@ +import os +import math +import pickle +import random +import numpy as np +import torch +import cv2 + +#################### +# Files & IO +#################### + +###################### get image path list ###################### +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def _get_paths_from_images(path): + '''get image path list from image folder''' + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +def _get_paths_from_lmdb(dataroot): + '''get image path list from lmdb meta info''' + meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) + paths = meta_info['keys'] + sizes = meta_info['resolution'] + if len(sizes) == 1: + sizes = sizes * len(paths) + return paths, sizes + + +def get_image_paths(data_type, dataroot): + '''get image path list + support lmdb or image files''' + paths, sizes = None, None + if data_type == 'lmdb': + if dataroot is not None: + paths, sizes = _get_paths_from_lmdb(dataroot) + return paths, sizes + elif data_type == 'img': + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + else: + raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) + + 
+###################### read images ###################### +def _read_img_lmdb(env, key, size): + '''read image from lmdb with key (w/ and w/o fixed size) + size: (C, H, W) tuple''' + with env.begin(write=False) as txn: + buf = txn.get(key.encode('ascii')) + img_flat = np.frombuffer(buf, dtype=np.uint8) + C, H, W = size + img = img_flat.reshape(H, W, C) + return img + + +def read_img(env, path, size=None): + '''read image by cv2 or from lmdb + return: Numpy float32, HWC, BGR, [0,1]''' + if env is None: # img + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + else: + img = _read_img_lmdb(env, path, size) + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) # added by jinliang + # img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + +def read_img_fromcv2(img): + '''transform image in cv2 to the same output as read_img + return: Numpy float32, HWC, BGR, [0,1]''' + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + +#################### +# image processing +# process on numpy image +#################### + + +def augment(img, hflip=True, rot=True, mode=None): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + if mode == 'LQ' or mode == 'GT' or mode == 'SRker' or mode == 'GTker' or mode == 'GTker_memory': + return _augment(img) + elif mode == 'LQGTker': + return [_augment(I) for I in img] + else: + raise NotImplementedError('{} dataloader has not been implemented yet.'.format(mode)) + + +def augment_flow(img_list, flow_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: + flow = flow[:, ::-1, :] + flow[:, :, 0] *= -1 + if vflip: + flow = flow[::-1, :, :] + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + rlt_img_list = [_augment(img) for img in img_list] + rlt_flow_list = [_augment_flow(flow) for flow in flow_list] + + return rlt_img_list, rlt_flow_list + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. 
+ # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +#################### +# Functions +#################### + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (( + (absx > 1) * (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. 
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: CHW RGB [0,1] + # output: CHW RGB [0,1] w/o round + + in_C, in_H, in_W = img.size() + _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = 
torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i]) + + return out_2 + + +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC BGR [0,1] + # output: HWC BGR [0,1] w/o round + img = torch.from_numpy(img) + + in_H, in_W, in_C = img.size() + _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i]) + out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i]) + out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i]) + + return out_2.numpy() + + + ### Load data kernel map ### +def load_ker_map_list(path): + real_ker_map_list = [] + batch_kermap = torch.load(path) + size_kermap = batch_kermap.size() + m = size_kermap[0] + for i in range(m): + real_ker_map_list.append(batch_kermap[i]) + + return real_ker_map_list + + +if __name__ == 
'__main__': + # test imresize function + # read images + img = cv2.imread('test.png') + img = img * 1.0 / 255 + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + # imresize + scale = 1 / 4 + import time + total_time = 0 + for i in range(10): + start_time = time.time() + rlt = imresize(img, scale, antialiasing=True) + use_time = time.time() - start_time + total_time += use_time + print('average time: {}'.format(total_time / 10)) + + import torchvision.utils + torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0, + normalize=False) diff --git a/codes/interactive_explore.py b/codes/interactive_explore.py new file mode 100644 index 0000000..2da5396 --- /dev/null +++ b/codes/interactive_explore.py @@ -0,0 +1,82 @@ +# https://matplotlib.org/3.1.0/gallery/text_labels_and_annotations/demo_annotation_box.html + +from matplotlib import pyplot as plt +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import numpy as np +import argparse + + +def main(): + # Parse the command line arguments + prog = argparse.ArgumentParser() + prog.add_argument('--path', type=str, default='../results/001_MANet_aniso_x4_test_stage1/toy_dataset/npz/toy1.npz', + help='path to kernel estimation npz file') + args = prog.parse_args() + + data = np.load(args.path) + + sr_img = data['sr_img'][:, :, [2, 1, 0]] + est_ker_sv = data['est_ker_sv'] + gt_k_np = data['gt_ker'] + + # create figure and plot scatter + fig = plt.figure() + + if gt_k_np.sum() == 0: + ax = fig.add_subplot(111) + im = ax.imshow(sr_img) + else: + + ax = fig.add_subplot(121) + im = ax.imshow(gt_k_np, vmin=gt_k_np.min(), vmax=gt_k_np.max()) + plt.colorbar(im, ax=ax) + ax.set_title('GT kernel') + + ax = fig.add_subplot(122) + im = ax.imshow(sr_img) + ax.set_title('View kernel estimation\n by hovering the cursor') + + # create the annotations box + image = OffsetImage(np.random.rand(13, 13), zoom=5) + xybox = (100., 100.) + ab = AnnotationBbox(image, (0, 0), xybox=xybox, xycoords='data', + boxcoords="offset points", pad=0.3, arrowprops=dict(arrowstyle="->")) + + # add it to the axes and make it invisible + ax.add_artist(ab) + ab.set_visible(False) + + def hover(event): + # if the mouse is over the scatter points + if im.contains(event)[0]: + + # get the figure size + w, h = fig.get_size_inches() * fig.dpi + ws = (event.x > w / 2.) * -1 + (event.x <= w / 2.) + hs = (event.y > h / 2.) * -1 + (event.y <= h / 2.) + # if event occurs in the top or right quadrant of the figure, + # change the annotation box position relative to mouse. 
+ ab.xybox = (xybox[0] * ws, xybox[1] * hs) + ab.xybox = (-150, 0) + # make annotation box visible + ab.set_visible(True) + # place it at the position of the hovered scatter point + ab.xy = (int(event.xdata), int(event.ydata)) + # set the image corresponding to that point + data = est_ker_sv[int(event.xdata) + int(event.ydata) * sr_img.shape[1], :, :] + data = data / data.max() + image.set_data(data) + + + else: + # if the mouse is not over a scatter point + ab.set_visible(False) + fig.canvas.draw_idle() + + # add callback for mouse moves + fig.canvas.mpl_connect('motion_notify_event', hover) + plt.show() + + +if __name__ == '__main__': + main() diff --git a/codes/models/B_model.py b/codes/models/B_model.py new file mode 100644 index 0000000..80d796f --- /dev/null +++ b/codes/models/B_model.py @@ -0,0 +1,290 @@ +# base model for blind SR, input LR, output kernel + SR +import logging +from collections import OrderedDict +import torch +import torch.nn as nn +import torch.nn.init as init +from torch.nn.parallel import DataParallel, DistributedDataParallel +from torch.cuda.amp import autocast as autocast +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel +from models.modules.loss import CharbonnierLoss +import utils.util as util + +logger = logging.getLogger('base') + + +class B_Model(BaseModel): + def __init__(self, opt): + super(B_Model, self).__init__(opt) + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + + # define network and load pretrained models + self.netG = networks.define_G(opt).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + # print network + self.print_network() + self.load() + self.load_K() # load the kernel estimation part + + # degradation model + self.degradation_model = DegradationModel(opt['kernel_size'], opt['scale'], sv_degradation=True) + self.degradation_model = nn.DataParallel(self.degradation_model) + self.cal_lr_psnr = opt['cal_lr_psnr'] + + if self.is_train: + train_opt = opt['train'] + self.netG.train() + + # HR loss + loss_type = train_opt['pixel_criterion'] + if loss_type == 'l1': + self.cri_pix = nn.L1Loss().to(self.device) + elif loss_type == 'l2': + self.cri_pix = nn.MSELoss().to(self.device) + elif loss_type == 'cb': + self.cri_pix = CharbonnierLoss().to(self.device) + elif loss_type is None: + self.cri_pix = None + else: + raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) + self.l_pix_w = train_opt['pixel_weight'] + + # LR loss + loss_type = train_opt['pixel_criterion_lr'] + if loss_type == 'l1': + self.cri_pix_lr = nn.L1Loss().to(self.device) + elif loss_type == 'l2': + self.cri_pix_lr = nn.MSELoss().to(self.device) + elif loss_type == 'cb': + self.cri_pix_lr = CharbonnierLoss().to(self.device) + elif loss_type is None: + self.cri_pix_lr = None + else: + raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) + self.l_pix_w_lr = train_opt['pixel_weight_lr'] + + # kernel loss + loss_type = train_opt['kernel_criterion'] + if loss_type == 'l1': + self.cri_ker = nn.L1Loss().to(self.device) + elif loss_type == 'l2': + self.cri_ker = nn.MSELoss().to(self.device) + elif loss_type == 'cb': + self.cri_ker = CharbonnierLoss().to(self.device) + elif loss_type is None: + self.cri_ker = None + else: + raise NotImplementedError('Loss type [{:s}] is not 
recognized.'.format(loss_type)) + self.l_ker_w = train_opt['kernel_weight'] + + # optimizers + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + optim_params = [] + for k, v in self.netG.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1'], train_opt['beta2'])) + self.optimizers.append(self.optimizer_G) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + print('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + def init_model(self, scale=0.1): + # Common practise for initialization. + for layer in self.netG.modules(): + if isinstance(layer, nn.Conv2d): + init.kaiming_normal_(layer.weight, a=0, mode='fan_in') + layer.weight.data *= scale # for residual block + if layer.bias is not None: + layer.bias.data.zero_() + elif isinstance(layer, nn.Linear): + init.kaiming_normal_(layer.weight, a=0, mode='fan_in') + layer.weight.data *= scale + if layer.bias is not None: + layer.bias.data.zero_() + elif isinstance(layer, nn.BatchNorm2d): + init.constant_(layer.weight, 1) + init.constant_(layer.bias.data, 0.0) + + def feed_data(self, data, LR_img, LR_n_img, ker_map, kernel): + self.real_H = data['GT'].to(self.device) # GT + self.var_L, self.var_LN, self.ker_map, self.real_K = LR_img.to(self.device), LR_n_img.to( + self.device), ker_map.to(self.device), kernel.to(self.device) + + def optimize_parameters(self, step, scaler): + self.optimizer_G.zero_grad() + + with autocast(): + l_all = 0 + self.fake_SR, self.fake_K = self.netG(self.var_LN, self.real_K) + + # hr loss + if self.cri_pix is not None: + l_pix = self.l_pix_w * self.cri_pix(self.fake_SR, self.real_H) + l_all += l_pix + self.log_dict['l_pix'] = l_pix.item() + + # kernel loss + if self.cri_ker is not None: + # times 1e4 since kernel pixel values are very small + l_ker = self.l_ker_w * self.cri_ker(self.fake_K * 10000, + self.real_K.unsqueeze(1).expand(-1, self.fake_K.size(1), -1, + -1) * 10000) / self.fake_K.size(1) + l_all += l_ker + self.log_dict['l_ker'] = l_ker.item() + + # lr loss + if self.cri_pix_lr is not None: + self.fake_LR = self.degradation_model(self.real_H, self.fake_K) + l_pix_lr = self.l_pix_w_lr * self.cri_pix_lr(self.fake_LR, self.var_L) # we should use LR before noise corruption as a ref + l_all += l_pix_lr + self.log_dict['l_pix_lr'] = l_pix_lr.item() + + else: + self.fake_LR = self.var_L + + scaler.scale(l_all).backward() + scaler.step(self.optimizer_G) + scaler.update() + + def test(self): + self.netG.eval() + with torch.no_grad(): + self.fake_SR, self.fake_K = self.netG(self.var_LN, self.real_K) + + if self.cal_lr_psnr: + # synthesized data + if self.real_H.shape[2] * self.real_H.shape[3] > 1: + self.fake_LR = 
self.degradation_model(self.real_H, self.fake_K) + # no HR + else: + self.fake_LR = self.degradation_model(self.fake_SR, self.fake_K) + else: + self.fake_LR = self.var_L + + self.netG.train() + + def test_x8(self): + # from https://github.com/thstkdgus35/EDSR-PyTorch + self.netG.eval() + + def _transform(v, op): + # if self.precision != 'single': v = v.float() + v2np = v.data.cpu().numpy() + if op == 'v': + tfnp = v2np[:, :, :, ::-1].copy() + elif op == 'h': + tfnp = v2np[:, :, ::-1, :].copy() + elif op == 't': + tfnp = v2np.transpose((0, 1, 3, 2)).copy() + + ret = torch.Tensor(tfnp).to(self.device) + # if self.precision == 'half': ret = ret.half() + + return ret + + lr_list = [self.var_LN] + for tf in 'v', 'h', 't': + lr_list.extend([_transform(t, tf) for t in lr_list]) + with torch.no_grad(): + sr_list = [self.netG(aug) for aug in lr_list] + for i in range(len(sr_list)): + if i > 3: + sr_list[i] = _transform(sr_list[i], 't') + if i % 4 > 1: + sr_list[i] = _transform(sr_list[i], 'h') + if (i % 4) % 2 == 1: + sr_list[i] = _transform(sr_list[i], 'v') + + output_cat = torch.cat(sr_list, dim=0) + self.fake_SR = output_cat.mean(dim=0, keepdim=True) + self.netG.train() + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['LQN'] = self.var_LN.detach()[0].float().cpu() + out_dict['LQ'] = self.var_L.detach()[0].float().cpu() + out_dict['LQE'] = self.fake_LR.detach()[0].float().cpu() + out_dict['SR'] = self.fake_SR.detach()[0].float().cpu() + out_dict['GT'] = self.real_H.detach()[0].float().cpu() + out_dict['ker_map'] = self.ker_map.detach()[0].float().cpu() + out_dict['KE'] = self.fake_K.detach()[0].float().cpu() + out_dict['K'] = self.real_K.detach()[0].float().cpu() + out_dict['Batch_SR'] = self.fake_SR.detach().float().cpu() # Batch SR, for train + out_dict['Batch_KE'] = self.fake_K.detach().float().cpu() # Batch SR, for train + return out_dict + + def print_network(self): + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + def load(self): + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) + + def load_K(self): + load_path_K = self.opt['path']['pretrain_model_K'] + if load_path_K is not None: + logger.info('Loading model for K [{:s}] ...'.format(load_path_K)) + self.load_network(load_path_K, self.netG, self.opt['path']['strict_load']) + + def save(self, iter_label): + self.save_network(self.netG, 'G', iter_label) + + +class DegradationModel(nn.Module): + def __init__(self, kernel_size=15, scale=4, sv_degradation=True): + super(DegradationModel, self).__init__() + if sv_degradation: + self.blur_layer = util.BatchBlur_SV(l=kernel_size, padmode='replication') + self.sample_layer = util.BatchSubsample(scale=scale) + else: + self.blur_layer = util.BatchBlur(l=kernel_size, padmode='replication') + self.sample_layer = util.BatchSubsample(scale=scale) + + def forward(self, image, kernel): + return self.sample_layer(self.blur_layer(image, kernel)) diff --git 
a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py new file mode 100644 index 0000000..88eb085 --- /dev/null +++ b/codes/models/SRGAN_model.py @@ -0,0 +1,267 @@ +import logging +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel +from models.modules.loss import GANLoss + +logger = logging.getLogger('base') + + +class SRGANModel(BaseModel): + def __init__(self, opt): + super(SRGANModel, self).__init__(opt) + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt['train'] + + # define networks and load pretrained models + self.netG = networks.define_G(opt).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + if self.is_train: + self.netD = networks.define_D(opt).to(self.device) + if opt['dist']: + self.netD = DistributedDataParallel(self.netD, + device_ids=[torch.cuda.current_device()]) + else: + self.netD = DataParallel(self.netD) + + self.netG.train() + self.netD.train() + + # define losses, optimizer and scheduler + if self.is_train: + # G pixel loss + if train_opt['pixel_weight'] > 0: + l_pix_type = train_opt['pixel_criterion'] + if l_pix_type == 'l1': + self.cri_pix = nn.L1Loss().to(self.device) + elif l_pix_type == 'l2': + self.cri_pix = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) + self.l_pix_w = train_opt['pixel_weight'] + else: + logger.info('Remove pixel loss.') + self.cri_pix = None + + # G feature loss + if train_opt['feature_weight'] > 0: + l_fea_type = train_opt['feature_criterion'] + if l_fea_type == 'l1': + self.cri_fea = nn.L1Loss().to(self.device) + elif l_fea_type == 'l2': + self.cri_fea = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) + self.l_fea_w = train_opt['feature_weight'] + else: + logger.info('Remove feature loss.') + self.cri_fea = None + if self.cri_fea: # load VGG perceptual loss + self.netF = networks.define_F(opt, use_bn=False).to(self.device) + if opt['dist']: + self.netF = DistributedDataParallel(self.netF, + device_ids=[torch.cuda.current_device()]) + else: + self.netF = DataParallel(self.netF) + + # GD gan loss + self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) + self.l_gan_w = train_opt['gan_weight'] + # D_update_ratio and D_init_iters + self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 + self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 + + # optimizers + # G + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + optim_params = [] + for k, v in self.netG.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1_G'], train_opt['beta2_G'])) + self.optimizers.append(self.optimizer_G) + # D + wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], + 
weight_decay=wd_D, + betas=(train_opt['beta1_D'], train_opt['beta2_D'])) + self.optimizers.append(self.optimizer_D) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + raise NotImplementedError('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + self.print_network() # print network + self.load() # load G and D if needed + + def feed_data(self, data, need_GT=True): + self.var_L = data['LQ'].to(self.device) # LQ + if need_GT: + self.var_H = data['GT'].to(self.device) # GT + input_ref = data['ref'] if 'ref' in data else data['GT'] + self.var_ref = input_ref.to(self.device) + + def optimize_parameters(self, step): + # G + for p in self.netD.parameters(): + p.requires_grad = False + + self.optimizer_G.zero_grad() + self.fake_H = self.netG(self.var_L) + + l_g_total = 0 + if step % self.D_update_ratio == 0 and step > self.D_init_iters: + if self.cri_pix: # pixel loss + l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) + l_g_total += l_g_pix + if self.cri_fea: # feature loss + real_fea = self.netF(self.var_H).detach() + fake_fea = self.netF(self.fake_H) + l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) + l_g_total += l_g_fea + + pred_g_fake = self.netD(self.fake_H) + if self.opt['train']['gan_type'] == 'gan': + l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) + elif self.opt['train']['gan_type'] == 'ragan': + pred_d_real = self.netD(self.var_ref).detach() + l_g_gan = self.l_gan_w * ( + self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + l_g_total += l_g_gan + + l_g_total.backward() + self.optimizer_G.step() + + # D + for p in self.netD.parameters(): + p.requires_grad = True + + self.optimizer_D.zero_grad() + l_d_total = 0 + pred_d_real = self.netD(self.var_ref) + pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G + if self.opt['train']['gan_type'] == 'gan': + l_d_real = self.cri_gan(pred_d_real, True) + l_d_fake = self.cri_gan(pred_d_fake, False) + l_d_total = l_d_real + l_d_fake + elif self.opt['train']['gan_type'] == 'ragan': + l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) + l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) + l_d_total = (l_d_real + l_d_fake) / 2 + + l_d_total.backward() + self.optimizer_D.step() + + # set log + if step % self.D_update_ratio == 0 and step > self.D_init_iters: + if self.cri_pix: + self.log_dict['l_g_pix'] = l_g_pix.item() + if self.cri_fea: + self.log_dict['l_g_fea'] = l_g_fea.item() + self.log_dict['l_g_gan'] = l_g_gan.item() + + self.log_dict['l_d_real'] = l_d_real.item() + self.log_dict['l_d_fake'] = l_d_fake.item() + self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) + self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) + + def test(self): + self.netG.eval() + with torch.no_grad(): + self.fake_H = self.netG(self.var_L) + self.netG.train() + + def 
get_current_log(self): + return self.log_dict + + def get_current_visuals(self, need_GT=True): + out_dict = OrderedDict() + out_dict['LQ'] = self.var_L.detach()[0].float().cpu() + out_dict['SR'] = self.fake_H.detach()[0].float().cpu() + if need_GT: + out_dict['GT'] = self.var_H.detach()[0].float().cpu() + return out_dict + + def print_network(self): + # Generator + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + if self.is_train: + # Discriminator + s, n = self.get_network_description(self.netD) + if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD, + DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, + self.netD.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netD.__class__.__name__) + if self.rank <= 0: + logger.info('Network D structure: {}, with parameters: {:,d}'.format( + net_struc_str, n)) + logger.info(s) + + if self.cri_fea: # F, Perceptual Network + s, n = self.get_network_description(self.netF) + if isinstance(self.netF, nn.DataParallel) or isinstance( + self.netF, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, + self.netF.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netF.__class__.__name__) + if self.rank <= 0: + logger.info('Network F structure: {}, with parameters: {:,d}'.format( + net_struc_str, n)) + logger.info(s) + + def load(self): + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) + load_path_D = self.opt['path']['pretrain_model_D'] + if self.opt['is_train'] and load_path_D is not None: + logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) + self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) + + def save(self, iter_step): + self.save_network(self.netG, 'G', iter_step) + self.save_network(self.netD, 'D', iter_step) diff --git a/codes/models/SR_model.py b/codes/models/SR_model.py new file mode 100644 index 0000000..e88937c --- /dev/null +++ b/codes/models/SR_model.py @@ -0,0 +1,186 @@ +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel +from models.modules.loss import CharbonnierLoss + +logger = logging.getLogger('base') + + +class SRModel(BaseModel): + def __init__(self, opt): + super(SRModel, self).__init__(opt) + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt['train'] + + # define network and load pretrained models + self.netG = networks.define_G(opt).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + # print network + self.print_network() + self.init_model() + self.load() + + if self.is_train: + 
self.netG.train() + + # loss + loss_type = train_opt['pixel_criterion'] + if loss_type == 'l1': + self.cri_pix = nn.L1Loss().to(self.device) + elif loss_type == 'l2': + self.cri_pix = nn.MSELoss().to(self.device) + elif loss_type == 'cb': + self.cri_pix = CharbonnierLoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) + self.l_pix_w = train_opt['pixel_weight'] + + # optimizers + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + optim_params = [] + for k, v in self.netG.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1'], train_opt['beta2'])) + self.optimizers.append(self.optimizer_G) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + raise NotImplementedError('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + def init_model(self): + # Common practise for initialization. + for layer in self.netG.modules(): + if isinstance(layer, torch.nn.Conv2d): + torch.nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu') + if layer.bias is not None: + torch.nn.init.constant_(layer.bias, val=0.0) + elif isinstance(layer, torch.nn.BatchNorm2d): + torch.nn.init.constant_(layer.weight, val=1.0) + torch.nn.init.constant_(layer.bias, val=0.0) + elif isinstance(layer, torch.nn.Linear): + torch.nn.init.xavier_normal_(layer.weight) + if layer.bias is not None: + torch.nn.init.constant_(layer.bias, val=0.0) + + def feed_data(self, data, need_GT=True): + self.var_L = data['LQ'].to(self.device) # LQ + if need_GT: + self.real_H = data['GT'].to(self.device) # GT + + def optimize_parameters(self, step): + self.optimizer_G.zero_grad() + self.fake_H = self.netG(self.var_L) + l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) + l_pix.backward() + self.optimizer_G.step() + + # set log + self.log_dict['l_pix'] = l_pix.item() + + def test(self): + self.netG.eval() + with torch.no_grad(): + self.fake_H = self.netG(self.var_L) + self.netG.train() + + def test_x8(self): + # from https://github.com/thstkdgus35/EDSR-PyTorch + self.netG.eval() + + def _transform(v, op): + # if self.precision != 'single': v = v.float() + v2np = v.data.cpu().numpy() + if op == 'v': + tfnp = v2np[:, :, :, ::-1].copy() + elif op == 'h': + tfnp = v2np[:, :, ::-1, :].copy() + elif op == 't': + tfnp = v2np.transpose((0, 1, 3, 2)).copy() + + ret = torch.Tensor(tfnp).to(self.device) + # if self.precision == 'half': ret = ret.half() + + return ret + + lr_list = [self.var_L] + for tf in 'v', 'h', 't': + lr_list.extend([_transform(t, tf) for t in lr_list]) + with torch.no_grad(): + sr_list = [self.netG(aug) for aug in lr_list] + for i in 
range(len(sr_list)): + if i > 3: + sr_list[i] = _transform(sr_list[i], 't') + if i % 4 > 1: + sr_list[i] = _transform(sr_list[i], 'h') + if (i % 4) % 2 == 1: + sr_list[i] = _transform(sr_list[i], 'v') + + output_cat = torch.cat(sr_list, dim=0) + self.fake_H = output_cat.mean(dim=0, keepdim=True) + self.netG.train() + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self, need_GT=True): + out_dict = OrderedDict() + out_dict['LQ'] = self.var_L.detach()[0].float().cpu() + out_dict['SR'] = self.fake_H.detach()[0].float().cpu() + if need_GT: + out_dict['GT'] = self.real_H.detach()[0].float().cpu() + return out_dict + + def print_network(self): + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + def load(self): + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) + + def save(self, iter_label): + self.save_network(self.netG, 'G', iter_label) diff --git a/codes/models/__init__.py b/codes/models/__init__.py new file mode 100644 index 0000000..e7de6ca --- /dev/null +++ b/codes/models/__init__.py @@ -0,0 +1,18 @@ +import logging +logger = logging.getLogger('base') + + +def create_model(opt): + model = opt['model'] + + if model == 'sr': + from .SR_model import SRModel as M + elif model == 'srgan': + from .SRGAN_model import SRGANModel as M + elif model == 'blind': + from .B_model import B_Model as M + else: + raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) + m = M(opt) + logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) + return m diff --git a/codes/models/base_model.py b/codes/models/base_model.py new file mode 100644 index 0000000..7477669 --- /dev/null +++ b/codes/models/base_model.py @@ -0,0 +1,127 @@ +import os +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel + + +class BaseModel(): + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def get_current_losses(self): + pass + + def print_network(self): + pass + + def save(self, label): + pass + + def load(self): + pass + + def _set_lr(self, lr_groups_l): + ''' set learning rate for warmup, + lr_groups_l: list for lr_groups. 
each for a optimizer''' + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + # get the initial lr, which is set by the scheduler + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, cur_iter, warmup_iter=-1): + for scheduler in self.schedulers: + scheduler.step() + #### set up warm up learning rate + if cur_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + # return self.schedulers[0].get_lr()[0] + return self.optimizers[0].param_groups[0]['lr'] + + def get_network_description(self, network): + '''Get the string and total parameters of the network''' + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + s = str(network) + n = sum(map(lambda x: x.numel(), network.parameters())) + return s, n + + def save_network(self, network, network_label, iter_label): + save_filename = '{}_{}.pth'.format(iter_label, network_label) + save_path = os.path.join(self.opt['path']['models'], save_filename) + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + state_dict = network.state_dict() + for key, param in state_dict.items(): + state_dict[key] = param.cpu() + torch.save(state_dict, save_path) + + def load_network(self, load_path, network, strict=True): + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + load_net = torch.load(load_path) + load_net_clean = OrderedDict() # remove unnecessary 'module.' 
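+ # DataParallel / DistributedDataParallel save their state_dict with every key
+ # prefixed by 'module.'; the loop below strips that prefix so checkpoints trained
+ # with or without a wrapper can be loaded interchangeably.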
+ for k, v in load_net.items(): + if k.startswith('module.'): + load_net_clean[k[7:]] = v + else: + load_net_clean[k] = v + network.load_state_dict(load_net_clean, strict=strict) + + def save_training_state(self, epoch, iter_step): + '''Saves training state during training, which will be used for resuming''' + state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + save_filename = '{}.state'.format(iter_step) + save_path = os.path.join(self.opt['path']['training_state'], save_filename) + torch.save(state, save_path) + + def resume_training(self, resume_state): + '''Resume the optimizers and schedulers for training''' + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + + for i, s in enumerate(resume_schedulers): + # manually change lr milestones + # from collections import Counter + # s['milestones'] = Counter([100000, 150000, 200000, 250000, 300000, 350000, 400000]) # for multistage lr + # s['restarts'] = [120001, 240001, 360001] # for cosine_restart_lr + + self.schedulers[i].load_state_dict(s) + + diff --git a/codes/models/lr_scheduler.py b/codes/models/lr_scheduler.py new file mode 100644 index 0000000..2304f0c --- /dev/null +++ b/codes/models/lr_scheduler.py @@ -0,0 +1,144 @@ +import math +from collections import Counter +from collections import defaultdict +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepLR_Restart(_LRScheduler): + def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, + clear_state=False, last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.clear_state = clear_state + self.restarts = restarts if restarts else [0] + self.restarts = [v + 1 for v in self.restarts] + self.restart_weights = weights if weights else [1] + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + if self.clear_state: + self.optimizer.state = defaultdict(dict) + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [ + group['lr'] * self.gamma**self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + +class CosineAnnealingLR_Restart(_LRScheduler): + def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): + self.T_period = T_period + self.T_max = self.T_period[0] # current T period + self.eta_min = eta_min + self.restarts = restarts if restarts else [0] + self.restarts = [v + 1 for v in self.restarts] + self.restart_weights = weights if weights else [1] + self.last_restart = 0 + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' 
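+ # T_period lists the cosine half-period of each segment between restarts; at every
+ # iteration listed in `restarts` the learning rate is reset to initial_lr scaled by
+ # the matching entry of `weights`, and the next entry of T_period takes effect.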
+ super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch == 0: + return self.base_lrs + elif self.last_epoch in self.restarts: + self.last_restart = self.last_epoch + self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / + (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * + (group['lr'] - self.eta_min) + self.eta_min + for group in self.optimizer.param_groups] + + +if __name__ == "__main__": + optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, + betas=(0.9, 0.99)) + ############################## + # MultiStepLR_Restart + ############################## + ## Original + lr_steps = [200000, 400000, 600000, 800000] + restarts = None + restart_weights = None + + ## two + lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] + restarts = [500000] + restart_weights = [1] + + ## four + lr_steps = [ + 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, + 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 + ] + restarts = [250000, 500000, 750000] + restart_weights = [1, 1, 1] + + scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, + clear_state=False) + + ############################## + # Cosine Annealing Restart + ############################## + ## two + T_period = [500000, 500000] + restarts = [500000] + restart_weights = [1] + + ## four + T_period = [250000, 250000, 250000, 250000] + restarts = [250000, 500000, 750000] + restart_weights = [1,1, 1] + + scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, + weights=restart_weights) + + ############################## + # Draw figure + ############################## + N_iter = 1000000 + lr_l = list(range(N_iter)) + for i in range(N_iter): + scheduler.step() + current_lr = optimizer.param_groups[0]['lr'] + lr_l[i] = current_lr + + import matplotlib as mpl + from matplotlib import pyplot as plt + import matplotlib.ticker as mtick + mpl.style.use('default') + import seaborn + seaborn.set(style='whitegrid') + seaborn.set_context('paper') + + plt.figure(1) + plt.subplot(111) + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) + plt.title('Title', fontsize=16, color='k') + plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') + legend = plt.legend(loc='upper right', shadow=False) + ax = plt.gca() + labels = ax.get_xticks().tolist() + for k, v in enumerate(labels): + labels[k] = str(int(v / 1000)) + 'K' + ax.set_xticklabels(labels) + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) + + ax.set_ylabel('Learning rate') + ax.set_xlabel('Iteration') + fig = plt.gcf() + plt.show() diff --git a/codes/models/lr_scheduler_test.py b/codes/models/lr_scheduler_test.py new file mode 100644 index 0000000..49b27f0 --- /dev/null +++ b/codes/models/lr_scheduler_test.py @@ -0,0 +1,101 @@ +import math +from collections import 
Counter +from collections import defaultdict +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepGradient_Restart(_LRScheduler): + def __init__(self, milestones, restarts=None, weights=None, gamma=0.1, + clear_state=False, last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.clear_state = clear_state + self.restarts = restarts if restarts else [0] + self.restarts = [v + 1 for v in self.restarts] + self.restart_weights = weights if weights else [1] + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' + + def get_lr(self): + if self.last_epoch in self.restarts: + if self.clear_state: + self.optimizer.state = defaultdict(dict) + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [ + group['lr'] * self.gamma**self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + +class CosineAnnealingGradient_Restart(): + def __init__(self, T_period, restarts=None, weights=None, last_epoch=-1): + self.T_period = T_period + self.T_max = self.T_period[0] # current T period + self.restarts = restarts if restarts else [0] + self.restarts = [v + 1 for v in self.restarts] + self.restart_weights = weights if weights else [1] + self.last_restart = 0 + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' + self.last_epoch = last_epoch + self.lr = self.restart_weights[0] + + def get_weight(self, last_epoch): + if last_epoch == 0: + return self.lr + elif last_epoch in self.restarts: + self.last_restart = last_epoch + self.T_max = self.T_period[self.restarts.index(last_epoch)] + self.lr = self.restart_weights[self.restarts.index(last_epoch)] + return self.lr + self.lr *= (1 + math.cos(math.pi * (last_epoch - self.last_restart) / self.T_max)) /\ + (1 + math.cos(math.pi * ((last_epoch - self.last_restart) - 1) / self.T_max)) + return self.lr + + +if __name__ == "__main__": + + ############################## + # Cosine Annealing Restart + ############################## + + scheduler = CosineAnnealingGradient_Restart(T_period=[250, 250, 250, 250], restarts=[0,250,500,750], + weights=[1,1,1,1], last_epoch=0) + + ############################## + # Draw figure + ############################## + N_iter = 1000 + lr_l = list(range(N_iter)) + for i in range(N_iter): + lr_l[i] = scheduler.get_weight(i) + + import matplotlib as mpl + from matplotlib import pyplot as plt + import matplotlib.ticker as mtick + mpl.style.use('default') + import seaborn + seaborn.set(style='whitegrid') + seaborn.set_context('paper') + + plt.figure(1) + plt.subplot(111) + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) + plt.title('Title', fontsize=16, color='k') + plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') + legend = plt.legend(loc='upper right', shadow=False) + ax = plt.gca() + labels = ax.get_xticks().tolist() + for k, v in enumerate(labels): + labels[k] = str(int(v / 1000)) + 'K' + ax.set_xticklabels(labels) + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) + + ax.set_ylabel('Learning rate') + ax.set_xlabel('Iteration') + fig = plt.gcf() + plt.show() diff --git a/codes/models/modules/MANet_arch.py b/codes/models/modules/MANet_arch.py new file mode 100644 index 0000000..4f2876b --- 
/dev/null +++ b/codes/models/modules/MANet_arch.py @@ -0,0 +1,301 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import functools +from collections import OrderedDict +import models.modules.module_util as mutil + + +def sequential(*args): + """Advanced nn.Sequential. + + Args: + nn.Sequential, nn.Module + + Returns: + nn.Sequential + """ + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError('sequential does not support OrderedDict input.') + return args[0] # No sequential is needed. + modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + + +# -------------------------------------------- +# MAConv and MABlock for MANet +# -------------------------------------------- + +class MAConv(nn.Module): + ''' Mutual Affine Convolution (MAConv) layer ''' + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias, split=2, reduction=2): + super(MAConv, self).__init__() + assert split >= 2, 'Num of splits should be larger than one' + + self.num_split = split + splits = [1 / split] * split + self.in_split, self.in_split_rest, self.out_split = [], [], [] + + for i in range(self.num_split): + in_split = round(in_channels * splits[i]) if i < self.num_split - 1 else in_channels - sum(self.in_split) + in_split_rest = in_channels - in_split + out_split = round(out_channels * splits[i]) if i < self.num_split - 1 else in_channels - sum(self.out_split) + + self.in_split.append(in_split) + self.in_split_rest.append(in_split_rest) + self.out_split.append(out_split) + + setattr(self, 'fc{}'.format(i), nn.Sequential(*[ + nn.Conv2d(in_split_rest, int(in_split_rest // reduction), 1, 1, 0, True), + nn.ReLU(inplace=True), + nn.Conv2d(int(in_split_rest // reduction), in_split * 2, 1, 1, 0, True), + ])) + setattr(self, 'conv{}'.format(i), nn.Conv2d(in_split, out_split, kernel_size, stride, padding, bias)) + + def forward(self, input): + input = torch.split(input, self.in_split, dim=1) + output = [] + + for i in range(self.num_split): + scale, translation = torch.split(getattr(self, 'fc{}'.format(i))(torch.cat(input[:i] + input[i + 1:], 1)), + (self.in_split[i], self.in_split[i]), dim=1) + output.append(getattr(self, 'conv{}'.format(i))(input[i] * torch.sigmoid(scale) + translation)) + + return torch.cat(output, 1) + + +class MABlock(nn.Module): + ''' Residual block based on MAConv ''' + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, + split=2, reduction=2): + super(MABlock, self).__init__() + + self.res = nn.Sequential(*[ + MAConv(in_channels, in_channels, kernel_size, stride, padding, bias, split, reduction), + nn.ReLU(inplace=True), + MAConv(in_channels, out_channels, kernel_size, stride, padding, bias, split, reduction), + ]) + + def forward(self, x): + return x + self.res(x) + + +# ------------------------------------------------- +# SFT layer and RRDB block for non-blind RRDB-SFT +# ------------------------------------------------- + +class SFT_Layer(nn.Module): + ''' SFT layer ''' + def __init__(self, nf=64, para=10): + super(SFT_Layer, self).__init__() + self.mul_conv1 = nn.Conv2d(para + nf, 32, kernel_size=3, stride=1, padding=1) + self.mul_leaky = nn.LeakyReLU(0.2) + self.mul_conv2 = nn.Conv2d(32, nf, kernel_size=3, stride=1, padding=1) + + self.add_conv1 = nn.Conv2d(para + nf, 
32, kernel_size=3, stride=1, padding=1) + self.add_leaky = nn.LeakyReLU(0.2) + self.add_conv2 = nn.Conv2d(32, nf, kernel_size=3, stride=1, padding=1) + + def forward(self, feature_maps, para_maps): + cat_input = torch.cat((feature_maps, para_maps), dim=1) + mul = torch.sigmoid(self.mul_conv2(self.mul_leaky(self.mul_conv1(cat_input)))) + add = self.add_conv2(self.add_leaky(self.add_conv1(cat_input))) + return feature_maps * mul + add + + +class ResidualDenseBlock_5C(nn.Module): + ''' Residual Dense Block ''' + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB_SFT(nn.Module): + ''' Residual in Residual Dense Block with SFT layer ''' + + def __init__(self, nf, gc=32, para=15): + super(RRDB_SFT, self).__init__() + self.SFT = SFT_Layer(nf=nf, para=para) + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, input): + out = self.SFT(input[0], input[1]) + out = self.RDB1(out) + out = self.RDB2(out) + out = self.RDB3(out) + return [out * 0.2 + input[0], input[1]] + + +# ------------------------------------------------------ +# MANet and its combinations with non-blind SR +# ------------------------------------------------------ + +class MANet(nn.Module): + ''' Network of MANet''' + def __init__(self, in_nc=3, kernel_size=21, nc=[128, 256], nb=1, split=2): + super(MANet, self).__init__() + self.kernel_size = kernel_size + + self.m_head = nn.Conv2d(in_channels=in_nc, out_channels=nc[0], kernel_size=3, padding=1, bias=True) + self.m_down1 = sequential(*[MABlock(nc[0], nc[0], bias=True, split=split) for _ in range(nb)], + nn.Conv2d(in_channels=nc[0], out_channels=nc[1], kernel_size=2, stride=2, padding=0, + bias=True)) + + self.m_body = sequential(*[MABlock(nc[1], nc[1], bias=True, split=split) for _ in range(nb)]) + + self.m_up1 = sequential(nn.ConvTranspose2d(in_channels=nc[1], out_channels=nc[0], + kernel_size=2, stride=2, padding=0, bias=True), + *[MABlock(nc[0], nc[0], bias=True, split=split) for _ in range(nb)]) + self.m_tail = nn.Conv2d(in_channels=nc[0], out_channels=kernel_size ** 2, kernel_size=3, padding=1, bias=True) + + self.softmax = nn.Softmax(1) + + def forward(self, x): + h, w = x.size()[-2:] + paddingBottom = int(np.ceil(h / 8) * 8 - h) + paddingRight = int(np.ceil(w / 8) * 8 - w) + x = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x) + + x1 = self.m_head(x) + x2 = self.m_down1(x1) + x = self.m_body(x2) + x = self.m_up1(x + x2) + x = self.m_tail(x + x1) + + x = x[..., :h, :w] + + x = self.softmax(x) + + return x + + +class MANet_s1(nn.Module): + ''' stage1, train MANet''' + + def __init__(self, in_nc=3, 
out_nc=3, nf=64, nb=10, gc=32, scale=4, pca_path='./pca_matrix_aniso21_15_x2.pth', + code_length=15, kernel_size=21, manet_nf=256, manet_nb=1, split=2): + super(MANet_s1, self).__init__() + self.scale = scale + self.kernel_size = kernel_size + + self.kernel_estimation = MANet(in_nc=in_nc, kernel_size=kernel_size, nc=[manet_nf, manet_nf * 2], + nb=manet_nb, split=split) + + def forward(self, x, gt_K): + # kernel estimation + kernel = self.kernel_estimation(x) + kernel = F.interpolate(kernel, scale_factor=self.scale, mode='nearest').flatten(2).permute(0, 2, 1) + kernel = kernel.view(-1, kernel.size(1), self.kernel_size, self.kernel_size) + + # no meaning + with torch.no_grad(): + out = F.interpolate(x, scale_factor=self.scale, mode='nearest') + + return out, kernel + + +class MANet_s2(nn.Module): + ''' stage2, train nonblind RRDB-SFT''' + + def __init__(self, in_nc=3, out_nc=3, nf=64, nb=10, gc=32, scale=4, pca_path='./pca_matrix_aniso21_15_x2.pth', + code_length=15, kernel_size=21, manet_nf=256, manet_nb=1, split=2): + super(MANet_s2, self).__init__() + self.scale = scale + self.kernel_size = kernel_size + + self.register_buffer('pca_matrix', torch.load(pca_path).unsqueeze(0).unsqueeze(3).unsqueeze(4)) + RRDB_SFT_block_f = functools.partial(RRDB_SFT, nf=nf, gc=gc, para=code_length) + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = mutil.make_layer(RRDB_SFT_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.upsampler = sequential(nn.Conv2d(nf, out_nc * (scale ** 2), kernel_size=3, stride=1, padding=1, bias=True), + nn.PixelShuffle(scale)) + + def forward(self, x, gt_K): + # GT kernel preprocessing + with torch.no_grad(): + kernel_pca_code = torch.mm(gt_K.flatten(1), self.pca_matrix.squeeze()) \ + .unsqueeze(2).unsqueeze(3).expand(-1, -1, x.shape[2], x.shape[3]) + # no meaning + kernel = gt_K + + # nonblind sr + lr_fea = self.conv_first(x) + fea = self.RRDB_trunk([lr_fea, kernel_pca_code]) + fea = lr_fea + self.trunk_conv(fea[0]) + out = self.upsampler(fea) + + return out, kernel + + +class MANet_s3(nn.Module): + ''' stage3, fine-tune nonblind SR model based on MANet predictions''' + + def __init__(self, in_nc=3, out_nc=3, nf=64, nb=10, gc=32, scale=4, pca_path='./pca_matrix_aniso21_15_x2.pth', + code_length=15, kernel_size=21, manet_nf=256, manet_nb=1, split=2): + super(MANet_s3, self).__init__() + self.scale = scale + self.kernel_size = kernel_size + + self.kernel_estimation = MANet(in_nc=in_nc, kernel_size=kernel_size, nc=[manet_nf, manet_nf * 2], + nb=manet_nb, split=split) + + self.register_buffer('pca_matrix', torch.load(pca_path).unsqueeze(0).unsqueeze(3).unsqueeze(4)) + RRDB_SFT_block_f = functools.partial(RRDB_SFT, nf=nf, gc=gc, para=code_length) + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = mutil.make_layer(RRDB_SFT_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.upsampler = sequential(nn.Conv2d(nf, out_nc * (scale ** 2), kernel_size=3, stride=1, padding=1, bias=True), + nn.PixelShuffle(scale)) + + def forward(self, x, gt_K): + # kernel estimation + with torch.no_grad(): + kernel = self.kernel_estimation(x) + kernel_pca_code = (kernel.unsqueeze(2) * self.pca_matrix).sum(1, keepdim=False) + kernel = F.interpolate(kernel, scale_factor=self.scale, mode='nearest').flatten(2).permute(0, 2, 1) + kernel = kernel.view(-1, kernel.size(1), self.kernel_size, self.kernel_size) + + # nonblind sr + lr_fea = self.conv_first(x) + fea = self.RRDB_trunk([lr_fea, kernel_pca_code]) 
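+ # RRDB_trunk receives [features, kernel_pca_code]; each RRDB_SFT block first applies
+ # the SFT layer, which predicts a per-pixel scale and shift from the PCA-compressed
+ # kernel map and modulates the features before the residual dense blocks. Only the
+ # feature half (fea[0]) is used below; the code map is passed through unchanged.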
+ fea = lr_fea + self.trunk_conv(fea[0]) + out = self.upsampler(fea) + + return out, kernel + + +if __name__ == '__main__': + model = MANet_s3() + print(model) + + x = torch.randn((2, 3, 100, 100)) + x, k = model(x, 0) + print(x.shape, k.shape) diff --git a/codes/models/modules/discriminator_vgg_arch.py b/codes/models/modules/discriminator_vgg_arch.py new file mode 100644 index 0000000..27dd6a1 --- /dev/null +++ b/codes/models/modules/discriminator_vgg_arch.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torchvision + + +class Discriminator_VGG_128(nn.Module): + def __init__(self, in_nc, nf): + super(Discriminator_VGG_128, self).__init__() + # [64, 128, 128] + self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) + self.bn0_1 = nn.BatchNorm2d(nf, affine=True) + # [64, 64, 64] + self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) + self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) + self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) + self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) + # [128, 32, 32] + self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) + self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) + self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) + self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) + # [256, 16, 16] + self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) + # [512, 8, 8] + self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) + + self.linear1 = nn.Linear(512 * 4 * 4, 100) + self.linear2 = nn.Linear(100, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.lrelu(self.conv0_0(x)) + fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) + + fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) + fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) + + fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) + fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) + + fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) + fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) + + fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) + fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) + + fea = fea.view(fea.size(0), -1) + fea = self.lrelu(self.linear1(fea)) + out = self.linear2(fea) + return out + + +class VGGFeatureExtractor(nn.Module): + def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, + device=torch.device('cpu')): + super(VGGFeatureExtractor, self).__init__() + self.use_input_norm = use_input_norm + if use_bn: + model = torchvision.models.vgg19_bn(pretrained=True) + else: + model = torchvision.models.vgg19(pretrained=True) + if self.use_input_norm: + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] + self.register_buffer('mean', mean) + self.register_buffer('std', std) + self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) + # No need to BP to variable + for k, v in self.features.named_parameters(): + v.requires_grad = 
False + + def forward(self, x): + # Assume input range is [0, 1] + if self.use_input_norm: + x = (x - self.mean) / self.std + output = self.features(x) + return output diff --git a/codes/models/modules/loss.py b/codes/models/modules/loss.py new file mode 100644 index 0000000..d43cb0e --- /dev/null +++ b/codes/models/modules/loss.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn + + +class CharbonnierLoss(nn.Module): + """Charbonnier Loss (L1)""" + + def __init__(self, eps=1e-6): + super(CharbonnierLoss, self).__init__() + self.eps = eps + + def forward(self, x, y): + diff = x - y + loss = torch.sum(torch.sqrt(diff * diff + self.eps)) + return loss + + +# Define GAN loss: [vanilla | lsgan | wgan-gp] +class GANLoss(nn.Module): + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type.lower() + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'gan' or self.gan_type == 'ragan': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan-gp': + + def wgan_loss(input, target): + # target is boolean + return -1 * input.mean() if target else input.mean() + + self.loss = wgan_loss + else: + raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) + + def get_target_label(self, input, target_is_real): + if self.gan_type == 'wgan-gp': + return target_is_real + if target_is_real: + return torch.empty_like(input).fill_(self.real_label_val) + else: + return torch.empty_like(input).fill_(self.fake_label_val) + + def forward(self, input, target_is_real): + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + return loss + + +class GradientPenaltyLoss(nn.Module): + def __init__(self, device=torch.device('cpu')): + super(GradientPenaltyLoss, self).__init__() + self.register_buffer('grad_outputs', torch.Tensor()) + self.grad_outputs = self.grad_outputs.to(device) + + def get_grad_outputs(self, input): + if self.grad_outputs.size() != input.size(): + self.grad_outputs.resize_(input.size()).fill_(1.0) + return self.grad_outputs + + def forward(self, interp, interp_crit): + grad_outputs = self.get_grad_outputs(interp_crit) + grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, + grad_outputs=grad_outputs, create_graph=True, + retain_graph=True, only_inputs=True)[0] + grad_interp = grad_interp.view(grad_interp.size(0), -1) + grad_interp_norm = grad_interp.norm(2, dim=1) + + loss = ((grad_interp_norm - 1)**2).mean() + return loss + +class ReconstructionLoss(nn.Module): + def __init__(self, losstype='l2', eps=1e-6): + super(ReconstructionLoss, self).__init__() + self.losstype = losstype + self.eps = eps + + def forward(self, x, target): + if self.losstype == 'l2': + return torch.mean(torch.sum((x - target)**2, (1, 2, 3))) + elif self.losstype == 'l1': + diff = x - target + return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3))) + else: + print("reconstruction loss type error!") + return 0 + + + diff --git a/codes/models/modules/module_util.py b/codes/models/modules/module_util.py new file mode 100644 index 0000000..ca5d7fa --- /dev/null +++ b/codes/models/modules/module_util.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + 
for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualBlock_noBN(nn.Module): + '''Residual block w/o BN + ---Conv-ReLU-Conv-+- + |________________| + ''' + + def __init__(self, nf=64): + super(ResidualBlock_noBN, self).__init__() + self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + # initialization + initialize_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = F.relu(self.conv1(x), inplace=True) + out = self.conv2(out) + return identity + out + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): + """Warp an image or feature map with optical flow + Args: + x (Tensor): size (N, C, H, W) + flow (Tensor): size (N, H, W, 2), normal value + interp_mode (str): 'nearest' or 'bilinear' + padding_mode (str): 'zeros' or 'border' or 'reflection' + + Returns: + Tensor: warped image or feature map + """ + assert x.size()[-2:] == flow.size()[1:3] + B, C, H, W = x.size() + # mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + grid = grid.type_as(x) + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) + return output diff --git a/codes/models/networks.py b/codes/models/networks.py new file mode 100644 index 0000000..20d3c72 --- /dev/null +++ b/codes/models/networks.py @@ -0,0 +1,66 @@ +import torch +import logging +import models.modules.discriminator_vgg_arch as SRGAN_arch +import models.modules.MANet_arch as MANet_arch + +logger = logging.getLogger('base') + + +#################### +# define network +#################### +#### Generator +def define_G(opt): + opt_net = opt['network_G'] + which_model = opt_net['which_model_G'] + + if which_model == 'MANet_s1': + netG = MANet_arch.MANet_s1(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], + scale=opt['scale'], pca_path=opt['pca_path'], code_length=opt['code_length'], + kernel_size=opt['kernel_size'], + manet_nf=opt_net['manet_nf'], manet_nb=opt_net['manet_nb'], split=opt_net['split']) + elif which_model == 'MANet_s2': + netG = MANet_arch.MANet_s2(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], + scale=opt['scale'], pca_path=opt['pca_path'], code_length=opt['code_length'], + kernel_size=opt['kernel_size'], + manet_nf=opt_net['manet_nf'], manet_nb=opt_net['manet_nb'], split=opt_net['split']) + elif which_model == 'MANet_s3': + netG = MANet_arch.MANet_s3(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], + scale=opt['scale'], pca_path=opt['pca_path'], 
code_length=opt['code_length'], + kernel_size=opt['kernel_size'], + manet_nf=opt_net['manet_nf'], manet_nb=opt_net['manet_nb'], split=opt_net['split']) + else: + raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) + return netG + + +# functions below are not used + +#### Discriminator +def define_D(opt): + opt_net = opt['network_D'] + which_model = opt_net['which_model_D'] + + if which_model == 'discriminator_vgg_128': + netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf']) + else: + raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) + return netD + + +#### Define Network used for Perceptual Loss +def define_F(opt, use_bn=False): + gpu_ids = opt['gpu_ids'] + device = torch.device('cuda' if gpu_ids else 'cpu') + # PyTorch pretrained VGG19-54, before ReLU. + if use_bn: + feature_layer = 49 + else: + feature_layer = 34 + netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, + use_input_norm=True, device=device) + netF.eval() # No need to train + return netF diff --git a/codes/options/options.py b/codes/options/options.py new file mode 100644 index 0000000..36b665a --- /dev/null +++ b/codes/options/options.py @@ -0,0 +1,120 @@ +import os +import os.path as osp +import logging +import yaml +from utils.util import OrderedYaml +Loader, Dumper = OrderedYaml() + + +def parse(opt_path, gpu_ids_qsub=None, is_train=True): + with open(opt_path, mode='r') as f: + opt = yaml.load(f, Loader=Loader) + # export CUDA_VISIBLE_DEVICES + if gpu_ids_qsub is not None: opt['gpu_ids'] = [int(x) for x in gpu_ids_qsub.split(',')] + gpu_list = ','.join(str(x) for x in opt['gpu_ids']) + os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list + print('export CUDA_VISIBLE_DEVICES=' + gpu_list) + + opt['is_train'] = is_train + if opt['distortion'] == 'sr': + scale = opt['scale'] + + # datasets + for phase, dataset in opt['datasets'].items(): + phase = phase.split('_')[0] + print(dataset) + dataset['phase'] = phase + if opt['distortion'] == 'sr': + dataset['scale'] = scale + is_lmdb = False + if dataset.get('dataroot_GT', None) is not None: + dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) + if dataset['dataroot_GT'].endswith('lmdb'): + is_lmdb = True + # if dataset.get('dataroot_GT_bg', None) is not None: + # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg']) + if dataset.get('dataroot_LQ', None) is not None: + dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) + if dataset['dataroot_LQ'].endswith('lmdb'): + is_lmdb = True + dataset['data_type'] = 'lmdb' if is_lmdb else 'img' + if dataset['mode'].endswith('mc'): # for memcached + dataset['data_type'] = 'mc' + dataset['mode'] = dataset['mode'].replace('_mc', '') + + # path + for key, path in opt['path'].items(): + if path and key in opt['path'] and key != 'strict_load': + opt['path'][key] = osp.expanduser(path) + opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) + if is_train: + experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_state'] = osp.join(experiments_root, 'training_state') + opt['path']['log'] = experiments_root + opt['path']['val_images'] = osp.join(experiments_root, 'val_images') + + # change some options for debug mode + if 'debug' in opt['name']: + opt['train']['val_freq'] = 8 + 
opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(opt['path']['root'], 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + + # network + if opt['distortion'] == 'sr': + opt['network_G']['scale'] = scale + + return opt + + +def dict2str(opt, indent_l=1): + '''dict to string for logger''' + msg = '' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_l * 2) + k + ':[\n' + msg += dict2str(v, indent_l + 1) + msg += ' ' * (indent_l * 2) + ']\n' + else: + msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' + return msg + + +class NoneDict(dict): + def __missing__(self, key): + return None + + +# convert to NoneDict, which return None for missing key. +def dict_to_nonedict(opt): + if isinstance(opt, dict): + new_opt = dict() + for key, sub_opt in opt.items(): + new_opt[key] = dict_to_nonedict(sub_opt) + return NoneDict(**new_opt) + elif isinstance(opt, list): + return [dict_to_nonedict(sub_opt) for sub_opt in opt] + else: + return opt + + +def check_resume(opt, resume_iter): + '''Check resume states and pretrain_model paths''' + logger = logging.getLogger('base') + if opt['path']['resume_state']: + if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( + 'pretrain_model_D', None) is not None: + logger.warning('pretrain_model path will be ignored when resuming training.') + + opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], + '{}_G.pth'.format(resume_iter)) + logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) + if 'gan' in opt['model']: + opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], + '{}_D.pth'.format(resume_iter)) + logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) diff --git a/codes/options/test/prepare_testset.yml b/codes/options/test/prepare_testset.yml new file mode 100644 index 0000000..c015f7f --- /dev/null +++ b/codes/options/test/prepare_testset.yml @@ -0,0 +1,52 @@ +name: 001_MANet_prepare_dataset +suffix: ~ +model: blind +distortion: sr +scale: ~ +gpu_ids: [0] +kernel_size: 21 +code_length: 15 +sig_min: ~ +sig_max: ~ +sig: ~ +sig1: ~ +sig2: ~ +theta: ~ +rate_iso: 0 # 1 for iso, 0 for aniso +sv_mode: ~ +test_noise: ~ +noise: ~ + +datasets: + test1: + name: Set5 + mode: GT + dataroot_GT: ../datasets/Set5/HR + dataroot_LQ: ~ + test2: + name: Set14 + mode: GT + dataroot_GT: ../datasets/Set14/HR + dataroot_LQ: ~ + test3: + name: BSD100 + mode: GT + dataroot_GT: ../datasets/BSD100/HR + dataroot_LQ: ~ + test4: + name: Urban100 + mode: GT + dataroot_GT: ../datasets/Urban100/HR + dataroot_LQ: ~ + +network_G: + which_model_G: ~ + in_nc: ~ + out_nc: ~ + nf: ~ + nb: ~ + upscale: 0 +# +path: + strict_load: true + pretrain_model_G: ~ diff --git a/codes/options/test/test_stage1.yml b/codes/options/test/test_stage1.yml new file mode 100644 index 0000000..8397422 --- /dev/null +++ b/codes/options/test/test_stage1.yml @@ -0,0 +1,64 @@ +name: 001_MANet_aniso_x4_test_stage1 +suffix: ~ +model: blind +distortion: sr +scale: 4 +crop_border: ~ +gpu_ids: [0] +kernel_size: 21 +code_length: 15 +sig_min: 0 +sig_max: 0 +sig: 3.0 +sig1: 6 +sig2: 1 +theta: 0.7853981633974483 +rate_iso: 0 # 1 for iso, 0 for aniso +sv_mode: 0 # 0 for spatially invariant kernel, 1-5 for spatially variant kernel types as in Table 2 +test_noise: False +noise: 15 +test_jpeg: False +jpeg: 50 +cal_lr_psnr: False # calculate lr pixel consumes huge memory + + +datasets: + # example1: HR 
input only (generating LR on-the-fly) + test_1: + name: toy_dataset1 + mode: GT + dataroot_GT: ../datasets/toy_dataset/HR_si + dataroot_LQ: ~ + + # example2: LR input only (no HR) +# test_2: +# name: toy_dataset2 +# mode: LQ +# dataroot_GT: ~ +# dataroot_LQ: ../datasets/toy_dataset/LR_mode0_noise0 +# scale: 4 +# kernel_size: 21 + + # example3: HR-LR pairs +# test_3: +# name: toy_dataset3 +# mode: GTLQ +# dataroot_GT: ../datasets/toy_dataset/HR_si +# dataroot_LQ: ../datasets/toy_dataset/LR_mode0_noise0 + + +network_G: + which_model_G: MANet_s1 + in_nc: 3 + out_nc: ~ + nf: ~ + nb: ~ + gc: ~ + manet_nf: 128 + manet_nb: 1 + split: 2 + + +path: + strict_load: true + pretrain_model_K: ../experiments/pretrained_models/stage1_MANet_x4.pth diff --git a/codes/options/test/test_stage2.yml b/codes/options/test/test_stage2.yml new file mode 100644 index 0000000..2d917d0 --- /dev/null +++ b/codes/options/test/test_stage2.yml @@ -0,0 +1,47 @@ +name: 001_MANet_aniso_x4_test_stage2 +suffix: ~ +model: blind +distortion: sr +scale: 4 +crop_border: ~ +gpu_ids: [0] +kernel_size: 21 +code_length: 15 +sig_min: 0 +sig_max: 0 +sig: 3.0 +sig1: 6 +sig2: 1 +theta: 0.7853981633974483 +rate_iso: 0 # 1 for iso, 0 for aniso +sv_mode: 0 # 0 for spatially invariant kernel, 1-5 for sv kernel as in Table 3 +test_noise: False +noise: 15 +test_jpeg: False +jpeg: 50 +pca_path: ./pca_matrix_aniso21_15_x4.pth + + +datasets: + # HR input only (generating LR on-the-fly) + test_1: + name: toy_dataset1 + mode: GT + dataroot_GT: ../datasets/toy_dataset/HR_si + dataroot_LQ: ~ + +network_G: + which_model_G: MANet_s2 + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 10 + gc: 32 + manet_nf: ~ + manet_nb: ~ + split: ~ + + +path: + strict_load: true + pretrain_model_G: ../experiments/pretrained_models/stage2_RRDB_x4.pth diff --git a/codes/options/test/test_stage3.yml b/codes/options/test/test_stage3.yml new file mode 100644 index 0000000..304d10c --- /dev/null +++ b/codes/options/test/test_stage3.yml @@ -0,0 +1,68 @@ +name: 001_MANet_aniso_x4_test_stage3 +suffix: ~ +model: blind +distortion: sr +scale: 4 +crop_border: ~ +gpu_ids: [0] +kernel_size: 21 +code_length: 15 +sig_min: 0 +sig_max: 0 +sig: 3.0 +sig1: 6 +sig2: 1 +theta: 0.7853981633974483 +rate_iso: 0 # 1 for iso, 0 for aniso +sv_mode: 0 # 0 for spatially invariant kernel, 1-5 for spatially variant kernel types as in Table 2 +test_noise: False +noise: 15 +test_jpeg: False +jpeg: 50 +pca_path: ./pca_matrix_aniso21_15_x4.pth +cal_lr_psnr: False # calculate lr pixel consumes huge memory + + +datasets: + # example1: HR input only (generating LR on-the-fly) + test_1: + name: toy_dataset1 + mode: GT + dataroot_GT: ../datasets/toy_dataset/HR_si + dataroot_LQ: ~ + + # example2: LR input only (no HR) +# test_2: +# name: toy_dataset2 +# mode: LQ +# dataroot_GT: ~ +# dataroot_LQ: ../datasets/toy_dataset/LR_mode0_noise0 +# scale: 4 +# kernel_size: 21 + + # example3: HR-LR pairs +# test_3: +# name: toy_dataset3 +# mode: GTLQ +# dataroot_GT: ../datasets/toy_dataset/HR_si +# dataroot_LQ: ../datasets/toy_dataset/LR_mode0_noise0 + + + +network_G: + which_model_G: MANet_s3 + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 10 + gc: 32 + manet_nf: 128 + manet_nb: 1 + split: 2 + + +path: + strict_load: true + pretrain_model_G: ../experiments/pretrained_models/stage3_MANet+RRDB_x4.pth + + diff --git a/codes/options/train/train_stage1.yml b/codes/options/train/train_stage1.yml new file mode 100644 index 0000000..6f3511c --- /dev/null +++ b/codes/options/train/train_stage1.yml @@ -0,0 +1,98 @@ +#### general settings +name: 
001_MANet_aniso_x4_DIV2K_40_stage1 +use_tb_logger: true +model: blind +distortion: sr +scale: 4 +gpu_ids: [0] +kernel_size: 21 +code_length: 15 +# train +sig_min: 0.7 # 0.7, 0.525, 0.35 for x4, x3, x2 +sig_max: 10.0 # 10, 7.5, 5 for x4, x3, x2 +train_noise: False +noise_high: 15 +train_jpeg: False +jpeg_low: 70 +# validation +sig: 1.6 +sig1: 6 # 6, 5, 4 for x4, x3, x2 +sig2: 1 +theta: 0 +rate_iso: 0 # 1 for iso, 0 for aniso +test_noise: False +noise: 15 +test_jpeg: False +jpeg: 70 +pca_path: ./pca_matrix_aniso21_15_x4.pth +cal_lr_psnr: False # calculate lr psnr consumes huge memory + + +#### datasets +datasets: + train: + name: DIV2K + mode: GT + dataroot_GT: ../datasets/DIV2K/HR + dataroot_LQ: ~ + + use_shuffle: true + n_workers: 8 + batch_size: 16 + GT_size: 192 + LR_size: ~ + use_flip: true + use_rot: true + color: RGB + val: + name: Set5 + mode: GT + dataroot_GT: ../datasets/Set5/HR + dataroot_LQ: ~ + + +#### network structures +network_G: + which_model_G: MANet_s1 + in_nc: 3 + out_nc: ~ + nf: ~ + nb: ~ + gc: ~ + manet_nf: 128 + manet_nb: 1 + split: 2 + + +#### path +path: + pretrain_model_G: ~ + strict_load: true + resume_state: ~ #../experiments/001_MANet_aniso_x4_DIV2K_40_stage1/training_state/5000.state + + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 2e-4 + lr_scheme: MultiStepLR + beta1: 0.9 + beta2: 0.999 + niter: 300000 + warmup_iter: -1 + lr_steps: [100000, 150000, 200000, 250000] + lr_gamma: 0.5 + restarts: ~ + restart_weights: ~ + eta_min: !!float 1e-7 + + kernel_criterion: l1 + kernel_weight: 1.0 + + manual_seed: 0 + val_freq: !!float 5e3 + + +#### logger +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_stage2.yml b/codes/options/train/train_stage2.yml new file mode 100644 index 0000000..dd73776 --- /dev/null +++ b/codes/options/train/train_stage2.yml @@ -0,0 +1,95 @@ +#### general settings +name: 001_MANet_aniso_x4_DIV2K+Flickr2K_stage2 +use_tb_logger: true +model: blind +distortion: sr +scale: 4 +gpu_ids: [0] +kernel_size: 21 +code_length: 15 +# train +sig_min: 0.7 # 0.7, 0.525, 0.35 for x4, x3, x2 +sig_max: 10.0 # 10, 7.5, 5 for x4, x3, x2 +train_noise: False +noise_high: 15 +train_jpeg: False +jpeg_low: 70 +# validation +sig: 1.6 +sig1: 6 # 6, 5, 4 for x4, x3, x2 +sig2: 1 +theta: 0 +rate_iso: 0 # 1 for iso, 0 for aniso +test_noise: False +noise: 15 +test_jpeg: False +jpeg: 70 +pca_path: ./pca_matrix_aniso21_15_x4.pth + + +#### datasets +datasets: + train: + name: DIV2K+Flickr2K + mode: GT + dataroot_GT: ../datasets/DIV2K+Flickr2K/HR + dataroot_LQ: ~ + + use_shuffle: true + n_workers: 8 + batch_size: 16 + GT_size: 192 + LR_size: ~ + use_flip: true + use_rot: true + color: RGB + val: + name: Set5 + mode: GT + dataroot_GT: ../datasets/Set5/HR + dataroot_LQ: ~ + + +#### network structures +network_G: + which_model_G: MANet_s2 + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 10 + gc: 32 + manet_nf: ~ + manet_nb: ~ + split: ~ + +#### path +path: + pretrain_model_G: ~ + strict_load: true + resume_state: ~ #../experiments/001_MANet_aniso_x4_DIV2K_40_stage2/training_state/5000.state + + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 2e-4 + lr_scheme: CosineAnnealingLR_Restart + beta1: 0.9 + beta2: 0.99 + niter: 480000 + warmup_iter: -1 + T_period: [120000, 120000, 120000, 120000, 120000] + restarts: [120000, 240000, 360000, 480000] + restart_weights: [1, 1, 1, 1] + eta_min: !!float 1e-7 + + pixel_criterion: l1 + pixel_weight: 1.0 + + manual_seed: 0 + val_freq: !!float 5e3 + + 
+#### logger +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_stage3.yml b/codes/options/train/train_stage3.yml new file mode 100644 index 0000000..62811ba --- /dev/null +++ b/codes/options/train/train_stage3.yml @@ -0,0 +1,98 @@ +#### general settings +name: 001_MANet_aniso_x4_DIV2K+Flickr2K_stage3 +use_tb_logger: true +model: blind +distortion: sr +scale: 4 +gpu_ids: [0] +kernel_size: 21 +code_length: 15 +# train +sig_min: 0.7 # 0.7, 0.525, 0.35 for x4, x3, x2 +sig_max: 10.0 # 10, 7.5, 5 for x4, x3, x2 +train_noise: False +noise_high: 15 +train_jpeg: False +jpeg_low: 70 +# validation +sig: 1.6 +sig1: 6 # 6, 5, 4 for x4, x3, x2 +sig2: 1 +theta: 0 +rate_iso: 0 # 1 for iso, 0 for aniso +test_noise: False +noise: 15 +test_jpeg: False +jpeg: 70 +pca_path: ./pca_matrix_aniso21_15_x4.pth +cal_lr_psnr: False # calculate lr psnr consumes huge memory + + +#### datasets +datasets: + train: + name: DIV2K+Flickr2K + mode: GT + dataroot_GT: ../datasets/DIV2K+Flickr2K/HR + dataroot_LQ: ~ + + use_shuffle: true + n_workers: 8 + batch_size: 16 + GT_size: 192 + LR_size: ~ + use_flip: true + use_rot: true + color: RGB + val: + name: Set5 + mode: GT + dataroot_GT: ../datasets/Set5/HR + dataroot_LQ: ~ + +#### network structures +network_G: + which_model_G: MANet_s3 + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 10 + gc: 32 + manet_nf: 128 + manet_nb: 1 + split: 2 + + +#### path +path: + pretrain_model_K: ../pretrained_models/stage1_MANet_x4.pth + pretrain_model_G: ../pretrained_models/stage2_RRDB_x4.pth + strict_load: false + resume_state: ~ #../experiments/001_MANet_aniso_x4_DIV2K_40_stage3/training_state/5000.state + + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 5e-5 + lr_scheme: MultiStepLR + beta1: 0.9 + beta2: 0.999 + niter: 250000 + warmup_iter: -1 + lr_steps: [50000, 100000, 150000, 200000, 250000] + lr_gamma: 0.5 + restarts: ~ + restart_weights: ~ + eta_min: !!float 1e-7 + + pixel_criterion: l1 + pixel_weight: 1.0 + + manual_seed: 0 + val_freq: !!float 5e3 + + +#### logger +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/pca_matrix_aniso21_15_x2.pth b/codes/pca_matrix_aniso21_15_x2.pth new file mode 100644 index 0000000..938c103 Binary files /dev/null and b/codes/pca_matrix_aniso21_15_x2.pth differ diff --git a/codes/pca_matrix_aniso21_15_x3.pth b/codes/pca_matrix_aniso21_15_x3.pth new file mode 100644 index 0000000..16b3a8b Binary files /dev/null and b/codes/pca_matrix_aniso21_15_x3.pth differ diff --git a/codes/pca_matrix_aniso21_15_x4.pth b/codes/pca_matrix_aniso21_15_x4.pth new file mode 100644 index 0000000..7e9911e Binary files /dev/null and b/codes/pca_matrix_aniso21_15_x4.pth differ diff --git a/codes/prepare_testset.py b/codes/prepare_testset.py new file mode 100644 index 0000000..37d1d9b --- /dev/null +++ b/codes/prepare_testset.py @@ -0,0 +1,117 @@ +import os.path +import logging +import argparse +import numpy as np +import torch +import sys +import options.options as option +from data import create_dataset, create_dataloader +import utils.util as util + + +def generate_dataset(opt, test_loader, save_dir_HR, save_dir_LR, device_id): + prepro = util.SRMDPreprocessing(opt['scale'], random=False, l=opt['kernel_size'], add_noise=opt['test_noise'], + noise_high=opt['noise'] / 255., rate_cln=-1, + device=torch.device('cuda:{}'.format(device_id)), sig=opt['sig'], sig1=opt['sig1'], + sig2=opt['sig2'], theta=opt['theta'], + sig_min=opt['sig_min'], sig_max=opt['sig_max'], 
rate_iso=opt['rate_iso'], + is_training=False, sv_mode=opt['sv_mode']) + + for test_data in test_loader: + img_name = os.path.splitext(os.path.basename(test_data['GT_path'][0]))[0] + test_data['GT'] = test_data['GT'].to(torch.device('cuda:{}'.format(device_id))) + GT_img = util.tensor2img(test_data['GT']) + LR_img, LR_n_img, ker_map, kernel = prepro(test_data['GT'], kernel=True) + LR_n_img = util.tensor2img(LR_n_img) # uint8 + + # save images + if opt['sv_mode'] == 0: + img_name += '_{:.1f}_{:.1f}_{:.1f}.png'.format(opt['sig1'], opt['sig2'], opt['theta']) + print('processing {:>30s} for scale {}, SI mode with kernel: '.format(img_name, opt['scale']), ker_map.cpu()) + else: + img_name += '.png' + print('processing {:>30s} for scale {}, SV mode {}'.format(img_name, opt['scale'], opt['sv_mode'])) + + util.save_img(GT_img, os.path.join(save_dir_HR, img_name)) + util.save_img(LR_n_img, os.path.join(save_dir_LR, img_name)) + + +def main(): + #### options + parser = argparse.ArgumentParser() + parser.add_argument('--opt', type=str, default='options/test/prepare_testset.yml', + help='Path to options YMAL file.') + opt = option.parse(parser.parse_args().opt, is_train=False) + opt = option.dict_to_nonedict(opt) + + #### mkdir and logger + util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' + and 'pretrain_model' not in key and 'resume' not in key)) + util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + logger = logging.getLogger('base') + logger.info(option.dict2str(opt)) + device_id = torch.cuda.current_device() + + # set random seed + util.set_random_seed(0) + + for scale in [2, 3, 4]: + opt['scale'] = scale + + #### Create test dataset and dataloader + test_loaders = [] + for phase, dataset_opt in sorted(opt['datasets'].items()): + dataset_opt['scale'] = scale + test_set = create_dataset(dataset_opt) + test_loader = create_dataloader(test_set, dataset_opt) + logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) + test_loaders.append(test_loader) + + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info('\nGenerating [{:s}]...'.format(test_set_name)) + + for noise in [0, 15]: + opt['noise'] = noise + if noise > 0: opt['test_noise'] = True + + if opt['noise'] == 15 and opt['scale'] != 4: + continue + + sv_modes = [0, 1, 2, 3, 4, 5] if test_set_name == 'BSD100' else [0] + for sv_mode in sv_modes: + opt['sv_mode'] = sv_mode + + save_dir = test_loader.dataset.opt['dataroot_GT'].replace('/HR', '_x{}'.format(opt['scale'])) + save_dir_HR = os.path.join(save_dir, 'HR_si') if opt['sv_mode'] == 0 else os.path.join(save_dir, + 'HR_sv') + save_dir_LR = os.path.join(save_dir, 'LR_mode{}_noise{}'.format(opt['sv_mode'], opt['noise'])) + util.mkdir(save_dir_HR) + util.mkdir(save_dir_LR) + + # spatial-invariant + if opt['sv_mode'] == 0: + for sig1 in [1, 1 + opt['scale'], 1 + 2 * opt['scale']]: + opt['sig1'] = sig1 + for sig2 in range(1, 1 + sig1, opt['scale']): + opt['sig2'] = sig2 + for theta in [0, np.pi / 4]: + opt['theta'] = theta + + if sig1 == sig2 and theta > 0: + continue + + generate_dataset(opt, test_loader, save_dir_HR, save_dir_LR, device_id) + + # spatial-variant + else: + generate_dataset(opt, test_loader, save_dir_HR, save_dir_LR, device_id) + + + print('\n \nNote: \nFor spatially invariant (SI) SR, HR and LR images are in `HR_si` and `LR_mode0_noise0`, respectively. 
\n' + 'For spatially variant (SV) SR, HR and LR images are organized as `HR_sv` and `LR_mode1_noise0`, respectively.\n\n') + +if __name__ == '__main__': + main() + sys.exit(0) diff --git a/codes/test.py b/codes/test.py new file mode 100644 index 0000000..58193eb --- /dev/null +++ b/codes/test.py @@ -0,0 +1,229 @@ +import os.path +import logging +import time +import argparse +from collections import OrderedDict +import numpy as np +import torch +import options.options as option +import utils.util as util +from data.util import bgr2ycbcr +from data import create_dataset, create_dataloader +from models import create_model + +#### options +parser = argparse.ArgumentParser() +parser.add_argument('--opt', type=str, default='options/test/test_stage1.yml', help='Path to options YMAL file.') +parser.add_argument('--save_kernel', action='store_true', default=False, help='Save Kernel Esimtation.') +args = parser.parse_args() +opt = option.parse(args.opt, is_train=False) +opt = option.dict_to_nonedict(opt) +device_id = torch.cuda.current_device() + +#### mkdir and logger +util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' + and 'pretrain_model' not in key and 'resume' not in key)) +util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) +logger = logging.getLogger('base') +logger.info(option.dict2str(opt)) + +# set random seed +util.set_random_seed(0) + +#### Create test dataset and dataloader +test_loaders = [] +for phase, dataset_opt in sorted(opt['datasets'].items()): + test_set = create_dataset(dataset_opt) + test_loader = create_dataloader(test_set, dataset_opt) + logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) + test_loaders.append(test_loader) + +# load pretrained model by default +model = create_model(opt) + +for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] # path opt[''] + logger.info('\nTesting [{:s}]...'.format(test_set_name)) + test_start_time = time.time() + dataset_dir = os.path.join(opt['path']['results_root'], test_set_name) + util.mkdir(dataset_dir) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnr_y'] = [] + test_results['ssim_y'] = [] + test_results['psnr_k'] = [] + test_results['mae_n'] = [] + test_results['lr_psnr_y'] = [] + test_results['lr_ssim_y'] = [] + + #### preprocessing for LR_img and kernel map + prepro = util.SRMDPreprocessing(opt['scale'], random=False, l=opt['kernel_size'], add_noise=opt['test_noise'], + noise_high=opt['noise'] / 255., add_jpeg=opt['test_jpeg'], jpeg_low=opt['jpeg'], + rate_cln=-1, device=torch.device('cuda:{}'.format(device_id)), sig=opt['sig'], + sig1=opt['sig1'], sig2=opt['sig2'], theta=opt['theta'], + sig_min=opt['sig_min'], sig_max=opt['sig_max'], rate_iso=opt['rate_iso'], + is_training=False, sv_mode=opt['sv_mode']) + + for test_data in test_loader: + real_image = True if test_loader.dataset.opt['dataroot_GT'] is None else False + generate_online = True if test_loader.dataset.opt['dataroot_GT'] is not None and test_loader.dataset.opt[ + 'dataroot_LQ'] is None else False + img_path = test_data['LQ_path'][0] if real_image else test_data['GT_path'][0] + img_name = os.path.splitext(os.path.basename(img_path))[0] + + if real_image: + LR_img, LR_n_img, ker_map, kernel = test_data['LQ'], test_data['LQ'], torch.ones(1, 1, 1), \ + torch.ones(1, 1, opt['kernel_size'], opt['kernel_size']) + elif generate_online: + 
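+            # LR input is synthesized on the fly from the HR image, using the blur kernel,
+            # noise and JPEG settings configured in the option file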
+            test_data['GT'] = test_data['GT'].to(torch.device('cuda:{}'.format(device_id)))
+            LR_img, LR_n_img, ker_map, kernel = prepro(test_data['GT'], kernel=True)
+            print(ker_map.cpu())
+        else:
+            # note that this is not suitable for non-blind testing, because the kernel is a dummy placeholder here
+            LR_img, LR_n_img, ker_map, kernel = test_data['LQ'], test_data['LQ'], torch.ones(1, 1, 1), \
+                torch.ones(1, 1, opt['kernel_size'], opt['kernel_size'])
+
+        model.feed_data(test_data, LR_img, LR_n_img, ker_map, kernel)
+        model.test()
+
+        visuals = model.get_current_visuals()
+
+        sr_img = util.tensor2img(visuals['SR'])  # uint8
+
+        # deal with the image margins for real images
+        if test_loader.dataset.opt['mode'] == 'LQ':
+            if opt['scale'] == 4:
+                real_crop = 3
+            elif opt['scale'] == 2:
+                real_crop = 6
+            elif opt['scale'] == 1:
+                real_crop = 11
+            assert real_crop * opt['scale'] * 2 > opt['kernel_size']
+            sr_img = sr_img[real_crop * opt['scale']:-real_crop * opt['scale'],
+                            real_crop * opt['scale']:-real_crop * opt['scale'], :]
+
+        # save images
+        suffix = opt['suffix']
+        if suffix:
+            save_img_path = os.path.join(dataset_dir, img_name + suffix + '.png')
+            save_ker_path = os.path.join(dataset_dir, '0kernel_{:s}{}.png'.format(img_name, suffix))
+            save_ker_SV_path = os.path.join(dataset_dir, 'npz', img_name + suffix + '.npz')
+        else:
+            save_img_path = os.path.join(dataset_dir, img_name + '.png')
+            save_ker_path = os.path.join(dataset_dir, '0kernel_{:s}.png'.format(img_name))
+            save_ker_SV_path = os.path.join(dataset_dir, 'npz', img_name + '.npz')
+        util.save_img(sr_img, save_img_path)
+        if args.save_kernel:
+            os.makedirs(os.path.join(dataset_dir, 'npz'), exist_ok=True)
+
+        # choose a kernel to visualize from SV kernels
+        if len(visuals['KE'].shape) > 2:
+            est_ker = util.tensor2img(visuals['KE'][300, :, :], np.float32)
+            est_ker_sv = visuals['KE'].float().cpu().numpy().astype(np.float32)
+        else:
+            est_ker = util.tensor2img(visuals['KE'], np.float32)
+            est_ker_sv = None
+
+        if real_image:
+            util.plot_kernel(est_ker, save_ker_path)
+            if args.save_kernel and est_ker_sv is not None:
+                np.savez(save_ker_SV_path, sr_img=sr_img, est_ker_sv=est_ker_sv, gt_ker=0)
+
+        # calculate PSNR for LR
+        gt_img_lr = util.tensor2img(visuals['LQ'])
+        sr_img_lr = util.tensor2img(visuals['LQE'])
+        gt_img_lr = gt_img_lr / 255.
+        sr_img_lr = sr_img_lr / 255.
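+        # LR consistency metrics below: PSNR/SSIM between the estimated LR image (LQE) and
+        # the input LR image (LQ), computed on the Y channel for RGB images after cropping the border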
+ + crop_border = opt['crop_border'] if opt['crop_border'] else opt['scale'] + if gt_img_lr.shape[2] == 3: # RGB image + sr_img_lr_y = bgr2ycbcr(sr_img_lr, only_y=True) + gt_img_lr_y = bgr2ycbcr(gt_img_lr, only_y=True) + if crop_border == 0: + cropped_sr_img_lr_y = sr_img_lr_y + cropped_gt_img_lr_y = gt_img_lr_y + else: + cropped_sr_img_lr_y = sr_img_lr_y[crop_border:-crop_border, crop_border:-crop_border] + cropped_gt_img_lr_y = gt_img_lr_y[crop_border:-crop_border, crop_border:-crop_border] + lr_psnr_y = util.calculate_psnr(cropped_sr_img_lr_y * 255, cropped_gt_img_lr_y * 255) + lr_ssim_y = util.calculate_ssim(cropped_sr_img_lr_y * 255, cropped_gt_img_lr_y * 255) + test_results['lr_psnr_y'].append(lr_psnr_y) + test_results['lr_ssim_y'].append(lr_ssim_y) + + # calculate PSNR and SSIM + if not real_image: + gt_ker = util.tensor2img(visuals['K'], np.float32) + + # for debug and visualization + if not gt_ker.shape == est_ker.shape: + gt_ker = est_ker + + util.plot_kernel(est_ker, save_ker_path, gt_ker) + if args.save_kernel and est_ker_sv is not None: + np.savez(save_ker_SV_path, sr_img=sr_img, est_ker_sv=est_ker_sv, gt_ker=gt_ker) + psnr_k = util.calculate_kernel_psnr(est_ker, gt_ker) + test_results['psnr_k'].append(psnr_k) + + gt_img = util.tensor2img(visuals['GT']) + gt_img = gt_img / 255. + sr_img = sr_img / 255. + + crop_border = opt['crop_border'] if opt['crop_border'] else opt['scale'] + if crop_border == 0: + cropped_sr_img = sr_img + cropped_gt_img = gt_img + else: + cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :] + cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :] + + psnr = 0 # util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) + ssim = 0 # util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + + if gt_img.shape[2] == 3: # RGB image + sr_img_y = bgr2ycbcr(sr_img, only_y=True) + gt_img_y = bgr2ycbcr(gt_img, only_y=True) + if crop_border == 0: + cropped_sr_img_y = sr_img_y + cropped_gt_img_y = gt_img_y + else: + cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border] + cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border] + psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255) + ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255) + test_results['psnr_y'].append(psnr_y) + test_results['ssim_y'].append(ssim_y) + logger.info( + '{:20s} - PSNR/SSIM: {:.2f}/{:.4f}; PSNR_Y/SSIM_Y: {:.2f}/{:.4f}; LR_PSNR_Y/LR_SSIM_Y: {:.2f}/{' + ':.4f}; PSNR_K: {:.2f} dB.'.format( + img_name, psnr, ssim, psnr_y, ssim_y, lr_psnr_y, lr_ssim_y, psnr_k)) + else: + logger.info( + '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_K: {:.6f} dB.'.format(img_name, psnr, ssim, psnr_k)) + else: + logger.info('{:20s} - LR_PSNR_Y/LR_SSIM_Y: {:.2f}/{:.4f}'.format(img_name, lr_psnr_y, lr_ssim_y)) + + ave_lr_psnr_y = sum(test_results['lr_psnr_y']) / len(test_results['lr_psnr_y']) + ave_lr_ssim_y = sum(test_results['lr_ssim_y']) / len(test_results['lr_ssim_y']) + if not real_image: # metrics + # Average PSNR/SSIM results + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + + if test_results['psnr_y'] and test_results['ssim_y']: + ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) + ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) + + ave_psnr_k = 
sum(test_results['psnr_k']) / len(test_results['psnr_k'])
+            logger.info(
+                '----{} ({} images), average PSNR_Y/SSIM_Y: {:.2f}/{:.4f}, LR_PSNR_Y/LR_SSIM_Y: {:.2f}/{:.4f}, '
+                'kernel PSNR: {:.2f}\n'.
+                    format(test_set_name, len(test_results['psnr_y']), ave_psnr_y, ave_ssim_y, ave_lr_psnr_y,
+                           ave_lr_ssim_y, ave_psnr_k))
+
+    else:
+        logger.info('LR_PSNR_Y/LR_SSIM_Y: {:.2f}/{:.4f}\n'.format(ave_lr_psnr_y, ave_lr_ssim_y))
diff --git a/codes/train.py b/codes/train.py
new file mode 100644
index 0000000..9c2d577
--- /dev/null
+++ b/codes/train.py
@@ -0,0 +1,347 @@
+import os
+import math
+import argparse
+import random
+import logging
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from data.data_sampler import DistIterSampler
+from data.util import bgr2ycbcr
+
+import options.options as option
+from utils import util
+from data import create_dataloader, create_dataset
+from models import create_model
+
+import socket
+import getpass
+
+def init_dist(backend='nccl', **kwargs):
+    ''' initialization for distributed training'''
+    # if mp.get_start_method(allow_none=True) is None:
+    if mp.get_start_method(allow_none=True) != 'spawn':  # returns the name of the start method used for starting processes
+        mp.set_start_method('spawn', force=True)  # 'spawn' is the default on Windows
+    rank = int(os.environ['RANK'])  # system env process ranks
+    num_gpus = torch.cuda.device_count()  # returns the number of GPUs available
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)  # initializes the default distributed process group
+
+
+def main():
+    ###### MANet train ######
+    #### setup options
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--opt', type=str, default='options/train/train_stage1.yml',
+                        help='Path to option YAML file of MANet.')
+    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
+                        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument('--gpu_ids_qsub', type=str, default=None)
+    parser.add_argument('--slurm_job_id', type=str, default=0)
+    args = parser.parse_args()
+    opt = option.parse(args.opt, args.gpu_ids_qsub, is_train=True)
+    device_id = torch.cuda.current_device()
+
+    # convert to NoneDict, which returns None for missing keys
+    opt = option.dict_to_nonedict(opt)
+
+    #### random seed
+    seed = opt['train']['manual_seed']
+    if seed is None:
+        seed = random.randint(1, 10000)
+    util.set_random_seed(seed)
+
+    pca_matrix_path = opt['pca_path']
+    if pca_matrix_path is not None:
+        if not os.path.exists(pca_matrix_path):
+            # create a PCA matrix from a large batch of kernels and save it, so that all kernels share the same kernel-map projection
+            batch_ker, _ = util.random_batch_kernel(batch=150000, l=opt['kernel_size'],
+                                                    sig_min=opt['sig_min'], sig_max=opt['sig_max'],
+                                                    rate_iso=opt['rate_iso'],
+                                                    scale=opt['scale'], tensor=False)
+            print('batch kernel shape: {}'.format(batch_ker.shape))
+            b = np.size(batch_ker, 0)
+            batch_ker = batch_ker.reshape((b, -1))
+            pca_matrix = util.PCA(batch_ker, k=opt['code_length']).float()
+            print('PCA matrix shape: {}'.format(pca_matrix.shape))
+            torch.save(pca_matrix, pca_matrix_path)
+            print('Saved PCA matrix at: {}'.format(pca_matrix_path))
+
+    #### distributed training settings
+    if args.launcher == 'none':  # disabled distributed training
+        opt['dist'] = False
+        rank = -1
+        print('Disabled distributed training.')
+    else:
+        opt['dist'] = True
+        init_dist()
+        world_size = torch.distributed.get_world_size()
+
rank = torch.distributed.get_rank() + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + #### loading resume state if exists + if opt['path'].get('resume_state', None): + # distributed resuming: all load into default GPU + device_id = torch.cuda.current_device() + resume_state = torch.load(opt['path']['resume_state'], + map_location=lambda storage, loc: storage.cuda(device_id)) + option.check_resume(opt, resume_state['iter']) # check resume options + else: + resume_state = None + + #### mkdir and loggers + if rank <= 0: + if resume_state is None: + util.mkdir_and_rename( + opt['path']['experiments_root']) # rename experiment folder if exists + util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' + and 'pretrain_model' not in key and 'resume' not in key)) + + # config loggers. Before it, the log will not work + util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + logger = logging.getLogger('base') + logger.info('{}@{}, GPU {}, Job_id {}, Job path {}'.format(getpass.getuser(), socket.gethostname(), + opt['gpu_ids'], args.slurm_job_id, os.getcwd())) + logger.info(option.dict2str(opt)) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + version = float(torch.__version__[0:3]) + if version >= 1.1: # PyTorch 1.1 + from torch.utils.tensorboard import SummaryWriter + else: + logger.info( + 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) + from tensorboardX import SummaryWriter + tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) + else: + util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) + logger = logging.getLogger('base') + + #### create train and val dataloader + dataset_ratio = 200 # enlarge the size of each epoch + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = create_dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) + total_iters = int(opt['train']['niter']) + total_epochs = int(math.ceil(total_iters / train_size)) + if opt['dist']: + train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) + total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) + else: + train_sampler = None + train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) + if rank <= 0: + logger.info('Number of train images: {:,d}, iters: {:,d}'.format( + len(train_set), train_size)) + logger.info('Total epochs needed: {:d} for iters {:,d}'.format( + total_epochs, total_iters)) + elif phase == 'val': + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader(val_set, dataset_opt, opt, None) + if rank <= 0: + logger.info('Number of val images in [{:s}]: {:d}'.format( + dataset_opt['name'], len(val_set))) + else: + raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) + assert train_loader is not None + assert val_loader is not None + + #### create model + model = create_model(opt) + + #### resume training + if resume_state: + logger.info('Resuming training from epoch: {}, iter: {}.'.format( + resume_state['epoch'], resume_state['iter'])) + + start_epoch = resume_state['epoch'] + current_step = resume_state['iter'] + model.resume_training(resume_state) # handle optimizers and schedulers 
+    else:
+        current_step = 0
+        start_epoch = 0
+
+    #### init online degradation function
+    prepro_train = util.SRMDPreprocessing(opt['scale'], random=True, l=opt['kernel_size'], add_noise=opt['train_noise'],
+                                          noise_high=opt['noise_high'] / 255., add_jpeg=opt['train_jpeg'], jpeg_low=opt['jpeg_low'],
+                                          rate_cln=-1, device=torch.device('cuda:{}'.format(device_id)), sig=opt['sig'],
+                                          sig1=opt['sig1'], sig2=opt['sig2'], theta=opt['theta'],
+                                          sig_min=opt['sig_min'], sig_max=opt['sig_max'], rate_iso=opt['rate_iso'],
+                                          is_training=True, sv_mode=0)
+    prepro_val = util.SRMDPreprocessing(opt['scale'], random=False, l=opt['kernel_size'], add_noise=opt['test_noise'],
+                                        noise_high=opt['noise'] / 255., add_jpeg=opt['test_jpeg'], jpeg_low=opt['jpeg'],
+                                        rate_cln=-1, device=torch.device('cuda:{}'.format(device_id)), sig=opt['sig'],
+                                        sig1=opt['sig1'], sig2=opt['sig2'], theta=opt['theta'],
+                                        sig_min=opt['sig_min'], sig_max=opt['sig_max'], rate_iso=opt['rate_iso'],
+                                        is_training=False, sv_mode=0)
+
+    #### training
+    # mixed precision
+    scaler = torch.cuda.amp.GradScaler()
+
+    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
+    for epoch in range(start_epoch, total_epochs + 1):
+        if opt['dist']:
+            train_sampler.set_epoch(epoch)
+        for _, train_data in enumerate(train_loader):
+            current_step += 1
+            if current_step > total_iters:
+                break
+
+            if train_data['LQ'].shape[2] == 1:
+                train_data['GT'] = train_data['GT'].to(torch.device('cuda:{}'.format(device_id)))
+                LR_img, LR_n_img, ker_map, kernel = prepro_train(train_data['GT'], kernel=True)
+            else:
+                # paired LR is provided; use dummy placeholders for the kernel map and kernel
+                LR_img, LR_n_img, ker_map, kernel = train_data['LQ'], train_data['LQ'], torch.zeros(1, 1), torch.zeros(1, 1, 1)
+
+            #### update learning rate, schedulers
+            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
+
+            #### training
+            model.feed_data(train_data, LR_img, LR_n_img, ker_map, kernel)
+            model.optimize_parameters(current_step, scaler)
+
+            #### log
+            if current_step % opt['logger']['print_freq'] == 0:
+                logs = model.get_current_log()
+                message = '