diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..bb305c2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2021][HCFlow Authors] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + diff --git a/README.md b/README.md index e3f6706..8f1b86b 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,90 @@ + # Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling (HCFlow, ICCV2021) -Official PyTorch code for Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling (HCFlow, ICCV2021) -# Stay tuned! The code is coming before 18th August. +This repository is the official PyTorch implementation of Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling +([arxiv](https://arxiv.org/pdf/2108.05301.pdf), [supp](https://github.com/JingyunLiang/HCFlow/releases/tag/v0.0)). + + +:rocket: :rocket: :rocket: **News**: + - Aug. 17, 2021: See our recent work for spatially variant kernel estimation: [Mutual Affine Network for Spatially Variant Kernel Estimation in Blind Image Super-Resolution (MANet), ICCV2021](https://github.com/JingyunLiang/MANet) + - Aug. 17, 2021: See our recent work for real-world image SR: [Designing a Practical Degradation Model for Deep Blind Image Super-Resolution (BSRGAN), ICCV2021](https://github.com/cszn/BSRGAN) + - Aug. 17, 2021: See our previous flow-based work: *[Flow-based Kernel Prior with Application to Blind Super-Resolution (FKP), CVPR2021](https://github.com/JingyunLiang/FKP).* + --- + +> Normalizing flows have recently demonstrated promising results for low-level vision tasks. For image super-resolution (SR), it learns to predict diverse photo-realistic high-resolution (HR) images from the low-resolution (LR) image rather than learning a deterministic mapping. For image rescaling, it achieves high accuracy by jointly modelling the downscaling and upscaling processes. While existing approaches employ specialized techniques for these two tasks, we set out to unify them in a single formulation. In this paper, we propose the hierarchical conditional flow (HCFlow) as a unified framework for image SR and image rescaling. More specifically, HCFlow learns a bijective mapping between HR and LR image pairs by modelling the distribution of the LR image and the rest high-frequency component simultaneously. In particular, the high-frequency component is conditional on the LR image in a hierarchical manner. To further enhance the performance, other losses such as perceptual loss and GAN loss are combined with the commonly used negative log-likelihood loss in training. Extensive experiments on general image SR, face image SR and image rescaling have demonstrated that the proposed HCFlow achieves state-of-the-art performance in terms of both quantitative metrics and visual quality. +>


+ +## Requirements +- Python 3.7, PyTorch == 1.7.1 +- Requirements: opencv-python, lpips, natsort, etc. +- Platforms: Ubuntu 16.04, cuda-11.0 + + +```bash +cd HCFlow-master +pip install -r requirements.txt +``` + +## Quick Run +To run the code without preparing data, use one of the following commands: +```bash +cd codes +# face image SR +python test_HCFlow.py --opt options/test/test_SR_CelebA_8X_HCFlow.yml + +# general image SR +python test_HCFlow.py --opt options/test/test_SR_DF2K_4X_HCFlow.yml + +# image rescaling +python test_HCFlow.py --opt options/test/test_Rescaling_DF2K_4X_HCFlow.yml +``` +--- + +## Data Preparation +The framework of this project is based on [MMSR](https://github.com/open-mmlab/mmediting) and [SRFlow](https://github.com/andreas128/SRFlow). To prepare data, put training and testing sets in `./datasets` as `./datasets/DIV2K/HR/0801.png`. Commonly used SR datasets can be downloaded [here](https://github.com/xinntao/BasicSR/blob/master/docs/DatasetPreparation.md#common-image-sr-datasets). +There are two ways to accelerate data loading: first, one can use `./scripts/png2npy.py` to generate `.npy` files and use `data/GTLQnpy_dataset.py`; second, one can use the `.pklv4` dataset format (*recommended*) and use `data/LRHR_PKL_dataset.py`. An illustrative `.npy` conversion sketch and a dataloader usage sketch are given at the end of this README. Please refer to [SRFlow](https://github.com/andreas128/SRFlow#dataset-how-to-train-on-your-own-data) for more details. Prepared datasets can be downloaded [here](http://data.vision.ee.ethz.ch/alugmayr/SRFlow/datasets.zip). + +## Training + +To train HCFlow for face image SR, general image SR or image rescaling, run one of the following commands: + +```bash +cd codes + +# face image SR +python train_HCFlow.py --opt options/train/train_SR_CelebA_8X_HCFlow.yml + +# general image SR +python train_HCFlow.py --opt options/train/train_SR_DF2K_4X_HCFlow.yml + +# image rescaling +python train_HCFlow.py --opt options/train/train_Rescaling_DF2K_4X_HCFlow.yml +``` +All trained models can be downloaded from [here](https://github.com/JingyunLiang/HCFlow/releases/tag/v0.0). + + +## Testing + +Please follow the **Quick Run** section. Just modify the dataset path in `test_HCFlow_*.yml`. + +## Results +We achieved state-of-the-art performance on general image SR, face image SR and image rescaling. +> +> +For more results, please refer to the [paper](https://arxiv.org/abs/2108.05301) and [supp](https://github.com/JingyunLiang/HCFlow/releases/tag/v0.0) for details. + +## Citation + @inproceedings{liang21hcflow, + title={Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling}, + author={Liang, Jingyun and Lugmayr, Andreas and Zhang, Kai and Danelljan, Martin and Van Gool, Luc and Timofte, Radu}, + booktitle={IEEE International Conference on Computer Vision}, + year={2021} + } + + +## License & Acknowledgement + +This project is released under the Apache 2.0 license. The codes are based on [MMSR](https://github.com/open-mmlab/mmediting), [SRFlow](https://github.com/andreas128/SRFlow), [IRN](https://github.com/pkuxmq/Invertible-Image-Rescaling) and [Glow-pytorch](https://github.com/chaiyujin/glow-pytorch). Please also follow their licenses. Thanks for their great work.
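As a rough illustration of the first acceleration option in **Data Preparation**, the sketch below shows one way PNG folders could be dumped to `.npy` files for `data/GTLQnpy_dataset.py`. It is a hypothetical stand-in for `./scripts/png2npy.py` (whose actual interface may differ), and the folder paths are placeholders; each image is stored as an HWC `uint8` array so that `read_img_fromnpy` can later rescale it to `[0, 1]`.

```python
# Hypothetical sketch, not the shipped ./scripts/png2npy.py: save each PNG as an
# HWC uint8 .npy file that data/GTLQnpy_dataset.py can load with np.load().
import glob
import os

import cv2
import numpy as np


def png_folder_to_npy(src_dir, dst_dir):
    os.makedirs(dst_dir, exist_ok=True)
    for path in sorted(glob.glob(os.path.join(src_dir, '*.png'))):
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # HWC, BGR, uint8
        name = os.path.splitext(os.path.basename(path))[0]
        np.save(os.path.join(dst_dir, name + '.npy'), img)


if __name__ == '__main__':
    # placeholder paths; adapt them to your dataset layout
    png_folder_to_npy('./datasets/DIV2K/HR', './datasets/DIV2K/HR_npy')
```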
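For reference, the dataset modules added under `codes/data/` in this diff are created through the small factory in `codes/data/__init__.py` (`create_dataset` / `create_dataloader`). Below is a minimal usage sketch, run from the `codes/` directory; the option dictionary mirrors the keys the dataset classes read, but the values and paths are placeholders rather than the shipped YAML configs.

```python
# Minimal sketch of the dataset/dataloader factory from codes/data/__init__.py.
# Option values and paths below are placeholders, not the repo's YAML configs.
from data import create_dataset, create_dataloader

dataset_opt = {
    'name': 'demo',                                   # used only for logging
    'phase': 'test',                                  # 'train' also needs GT_size, batch_size, n_workers
    'mode': 'GTLQ',                                   # -> data/GTLQ_dataset.py (paired HR-LR images)
    'data_type': 'img',                               # read image files instead of an lmdb
    'dataroot_GT': '../datasets/DIV2K/HR',            # placeholder paths
    'dataroot_LQ': '../datasets/DIV2K/LR_bicubic/X4',
    'scale': 4,
    'color': None,                                    # keep the original color space
}

dataset = create_dataset(dataset_opt)
loader = create_dataloader(dataset, dataset_opt)      # test phase: batch_size=1, no shuffle
batch = next(iter(loader))
print(batch['LQ'].shape, batch['GT'].shape)           # (N, C, H, W) tensors in [0, 1]
```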
+ diff --git a/codes/data/GTLQ_dataset.py b/codes/data/GTLQ_dataset.py new file mode 100644 index 0000000..ffea133 --- /dev/null +++ b/codes/data/GTLQ_dataset.py @@ -0,0 +1,129 @@ +import random +import numpy as np +import cv2 +import lmdb +import torch +import torch.utils.data as data +import data.util as util +import sys +import os + +try: + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from data.util import imresize_np + from utils import util as utils +except ImportError: + pass + + +class GTLQDataset(data.Dataset): + ''' + Load HR-LR image pairs. + ''' + + def __init__(self, opt): + super(GTLQDataset, self).__init__() + self.opt = opt + self.LR_paths, self.GT_paths = None, None + self.LR_env, self.GT_env = None, None # environment for lmdb + self.scale = opt['scale'] + if self.opt['phase'] == 'train': + self.GT_size = opt['GT_size'] + self.LR_size = self.GT_size // self.scale + + # read image list from lmdb or image files + if opt['data_type'] == 'lmdb': + self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) + self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) + elif opt['data_type'] == 'img': + self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list + self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list + else: + print('Error: data_type is not matched in Dataset') + assert self.GT_paths, 'Error: GT paths are empty.' + if self.LR_paths and self.GT_paths: + assert len(self.LR_paths) == len( + self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( + len(self.LR_paths), len(self.GT_paths)) + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, + meminit=False) + if self.opt['dataroot_LQ'] is not None: + self.LR_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) + else: + self.LR_env = 'No lmdb input for LR' + + def __getitem__(self, index): + if self.opt['data_type'] == 'lmdb': + if (self.GT_env is None) or (self.LR_env is None): + self._init_lmdb() + + if self.opt['data_type'] == 'lmdb': + resolution = [int(s) for s in self.GT_sizes[index].split('_')] + else: + resolution = None + + + # loading code from srflow test + # img_GT = cv2.imread(GT_path)[:, :, [2, 1, 0]] + # img_GT = torch.Tensor(img_GT.transpose([2, 0, 1]).astype(np.float32)) / 255 + # img_LR = cv2.imread(LR_path)[:, :, [2, 1, 0]] + # pad_factor = 2 + # h, w, c = img_LR.shape + # img_LR = impad(img_LR, bottom=int(np.ceil(h / pad_factor) * pad_factor - h), + # right=int(np.ceil(w / pad_factor) * pad_factor - w)) + # img_LR = torch.Tensor(img_LR.transpose([2, 0, 1]).astype(np.float32)) / 255 + + + # get GT and LR image + GT_path = self.GT_paths[index] + LR_path = self.LR_paths[index] + # LR_path = GT_path.replace('HR', 'LR_bicubic/X4').replace('.png','x{}.png'.format(self.scale)) + img_GT = util.read_img(self.GT_env, GT_path, resolution) # return: Numpy float32, HWC, BGR, [0,1] + img_LR = util.read_img(self.LR_env, LR_path, resolution) + + + if self.opt['phase'] == 'train': + # crop + H, W, C = img_LR.shape + rnd_top_LR = random.randint(0, max(0, H - self.LR_size)) + rnd_left_LR = random.randint(0, max(0, W - self.LR_size)) + rnd_top_GT = rnd_top_LR * self.scale + rnd_left_GT = rnd_left_LR * self.scale + + img_GT = img_GT[rnd_top_GT:rnd_top_GT + self.GT_size, 
rnd_left_GT:rnd_left_GT + self.GT_size, :] + img_LR = img_LR[rnd_top_LR:rnd_top_LR + self.LR_size, rnd_left_LR:rnd_left_LR + self.LR_size, :] + + # augmentation - flip, rotate + img_GT, img_LR = util.augment([img_GT, img_LR], self.opt['use_flip'], + self.opt['use_rot'], self.opt['mode']) + + # change color space if necessary, deal with gray image + if self.opt['color']: + img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] + img_LR = util.channel_convert(img_LR.shape[2], self.opt['color'], [img_LR])[0] + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_GT.shape[2] == 3: + img_GT = img_GT[:, :, [2, 1, 0]] + if img_LR.shape[2] == 3: + img_LR = img_LR[:, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() + + + # modcrop + _, H, W = img_LR.size() + img_GT = img_GT[:, :H*self.scale, :W*self.scale] + + return {'LQ': img_LR, 'GT': img_GT, 'LQ_path': LR_path, 'GT_path': GT_path} + + def __len__(self): + return len(self.GT_paths) + + +def impad(img, top=0, bottom=0, left=0, right=0, color=255): + return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect') diff --git a/codes/data/GTLQnpy_dataset.py b/codes/data/GTLQnpy_dataset.py new file mode 100644 index 0000000..89eae1d --- /dev/null +++ b/codes/data/GTLQnpy_dataset.py @@ -0,0 +1,82 @@ +import random +import numpy as np +import cv2 +import lmdb +import torch +import torch.utils.data as data +import data.util as util +import sys +import os + +try: + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from data.util import imresize_np + from utils import util as utils +except ImportError: + pass + + +class GTLQnpyDataset(data.Dataset): + ''' + Load HR-LR image npy pairs. Make sure HR-LR images are in the same order. + ''' + + def __init__(self, opt): + super(GTLQnpyDataset, self).__init__() + self.opt = opt + self.LR_paths, self.GT_paths = None, None + self.scale = opt['scale'] + if self.opt['phase'] == 'train': + self.GT_size = opt['GT_size'] + self.LR_size = self.GT_size // self.scale + + self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list + self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list + + assert self.GT_paths, 'Error: GT paths are empty.' 
+ if self.LR_paths and self.GT_paths: + assert len(self.LR_paths) == len( + self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( + len(self.LR_paths), len(self.GT_paths)) + + def __getitem__(self, index): + # get GT and LR image + GT_path = self.GT_paths[index] + # LR_path = self.LR_paths[index] + LR_path = GT_path.replace('DIV2K+Flickr2K_HR', 'DIV2K+Flickr2K_LR_bicubic/X4').replace('.npy','x{}.npy'.format(self.scale)) + img_GT = util.read_img_fromnpy(np.load(GT_path)) + img_LR = util.read_img_fromnpy(np.load(LR_path)) # return: Numpy float32, HWC, BGR, [0,1] + + if self.opt['phase'] == 'train': + # crop + H, W, C = img_LR.shape + rnd_top_LR = random.randint(0, max(0, H - self.LR_size)) + rnd_left_LR = random.randint(0, max(0, W - self.LR_size)) + rnd_top_GT = rnd_top_LR * self.scale + rnd_left_GT = rnd_left_LR * self.scale + + img_GT = img_GT[rnd_top_GT:rnd_top_GT + self.GT_size, rnd_left_GT:rnd_left_GT + self.GT_size, :] + img_LR = img_LR[rnd_top_LR:rnd_top_LR + self.LR_size, rnd_left_LR:rnd_left_LR + self.LR_size, :] + + # augmentation - flip, rotate + img_GT, img_LR = util.augment([img_GT, img_LR], self.opt['use_flip'], + self.opt['use_rot'], self.opt['mode']) + + # change color space if necessary, deal with gray image + if self.opt['color']: + img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] + img_LR = util.channel_convert(img_LR.shape[2], self.opt['color'], [img_LR])[0] + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_GT.shape[2] == 3: + img_GT = img_GT[:, :, [2, 1, 0]] + if img_LR.shape[2] == 3: + img_LR = img_LR[:, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() + + return {'LQ': img_LR, 'GT': img_GT, 'LQ_path': LR_path, 'GT_path': GT_path} + + def __len__(self): + return len(self.GT_paths) + diff --git a/codes/data/GTLQx_dataset.py b/codes/data/GTLQx_dataset.py new file mode 100644 index 0000000..1bc5c15 --- /dev/null +++ b/codes/data/GTLQx_dataset.py @@ -0,0 +1,129 @@ +import random +import numpy as np +import cv2 +import lmdb +import torch +import torch.utils.data as data +import data.util as util +import sys +import os + +try: + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from data.util import imresize_np + from utils import util as utils +except ImportError: + pass + + +class GTLQxDataset(data.Dataset): + ''' + Load HR-LR image pairs. + ''' + + def __init__(self, opt): + super(GTLQxDataset, self).__init__() + self.opt = opt + self.LR_paths, self.GT_paths = None, None + self.LR_env, self.GT_env = None, None # environment for lmdb + self.scale = opt['scale'] + if self.opt['phase'] == 'train': + self.GT_size = opt['GT_size'] + self.LR_size = self.GT_size // self.scale + + # read image list from lmdb or image files + if opt['data_type'] == 'lmdb': + self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) + self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) + elif opt['data_type'] == 'img': + self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list + self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list + else: + print('Error: data_type is not matched in Dataset') + assert self.GT_paths, 'Error: GT paths are empty.' 
+ if self.LR_paths and self.GT_paths: + assert len(self.LR_paths) == len( + self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( + len(self.LR_paths), len(self.GT_paths)) + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, + meminit=False) + if self.opt['dataroot_LQ'] is not None: + self.LR_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) + else: + self.LR_env = 'No lmdb input for LR' + + def __getitem__(self, index): + if self.opt['data_type'] == 'lmdb': + if (self.GT_env is None) or (self.LR_env is None): + self._init_lmdb() + + if self.opt['data_type'] == 'lmdb': + resolution = [int(s) for s in self.GT_sizes[index].split('_')] + else: + resolution = None + + + # loading code from srflow test + # img_GT = cv2.imread(GT_path)[:, :, [2, 1, 0]] + # img_GT = torch.Tensor(img_GT.transpose([2, 0, 1]).astype(np.float32)) / 255 + # img_LR = cv2.imread(LR_path)[:, :, [2, 1, 0]] + # pad_factor = 2 + # h, w, c = img_LR.shape + # img_LR = impad(img_LR, bottom=int(np.ceil(h / pad_factor) * pad_factor - h), + # right=int(np.ceil(w / pad_factor) * pad_factor - w)) + # img_LR = torch.Tensor(img_LR.transpose([2, 0, 1]).astype(np.float32)) / 255 + + + # get GT and LR image + GT_path = self.GT_paths[index] + # LR_path = self.LR_paths[index] + LR_path = GT_path.replace('HR', 'LR_bicubic/X4').replace('.png','x{}.png'.format(self.scale)) + img_GT = util.read_img(self.GT_env, GT_path, resolution) # return: Numpy float32, HWC, BGR, [0,1] + img_LR = util.read_img(self.LR_env, LR_path, resolution) + + + if self.opt['phase'] == 'train': + # crop + H, W, C = img_LR.shape + rnd_top_LR = random.randint(0, max(0, H - self.LR_size)) + rnd_left_LR = random.randint(0, max(0, W - self.LR_size)) + rnd_top_GT = rnd_top_LR * self.scale + rnd_left_GT = rnd_left_LR * self.scale + + img_GT = img_GT[rnd_top_GT:rnd_top_GT + self.GT_size, rnd_left_GT:rnd_left_GT + self.GT_size, :] + img_LR = img_LR[rnd_top_LR:rnd_top_LR + self.LR_size, rnd_left_LR:rnd_left_LR + self.LR_size, :] + + # augmentation - flip, rotate + img_GT, img_LR = util.augment([img_GT, img_LR], self.opt['use_flip'], + self.opt['use_rot'], self.opt['mode']) + + # change color space if necessary, deal with gray image + if self.opt['color']: + img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] + img_LR = util.channel_convert(img_LR.shape[2], self.opt['color'], [img_LR])[0] + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_GT.shape[2] == 3: + img_GT = img_GT[:, :, [2, 1, 0]] + if img_LR.shape[2] == 3: + img_LR = img_LR[:, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() + + + # modcrop + _, H, W = img_LR.size() + img_GT = img_GT[:, :H*self.scale, :W*self.scale] + + return {'LQ': img_LR, 'GT': img_GT, 'LQ_path': LR_path, 'GT_path': GT_path} + + def __len__(self): + return len(self.GT_paths) + + +def impad(img, top=0, bottom=0, left=0, right=0, color=255): + return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect') diff --git a/codes/data/GT_dataset.py b/codes/data/GT_dataset.py new file mode 100644 index 0000000..7bd5294 --- /dev/null +++ b/codes/data/GT_dataset.py @@ -0,0 +1,131 @@ +import random +import numpy as np +import cv2 +import lmdb +import torch +import 
torch.utils.data as data +import data.util as util +import sys +import os + +try: + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from data.util import imresize_np + from utils import util as utils +except ImportError: + pass + + +class GTDataset(data.Dataset): + ''' + Load GT images only. 30s faster than LQGTKer (90s for 200iter). + ''' + + def __init__(self, opt): + super(GTDataset, self).__init__() + self.opt = opt + self.LR_paths, self.GT_paths = None, None + self.LR_env, self.GT_env = None, None # environment for lmdb + self.LR_size, self.GT_size = opt['LR_size'], opt['GT_size'] + + # read image list from lmdb or image files + if opt['data_type'] == 'lmdb': + self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) + self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) + elif opt['data_type'] == 'img': + self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list + self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list + else: + print('Error: data_type is not matched in Dataset') + assert self.GT_paths, 'Error: GT paths are empty.' + if self.LR_paths and self.GT_paths: + assert len(self.LR_paths) == len( + self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( + len(self.LR_paths), len(self.GT_paths)) + self.random_scale_list = [1] + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, + meminit=False) + if self.opt['dataroot_LQ'] is not None: + self.LR_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) + else: + self.LR_env = 'No lmdb input for LR' + + def __getitem__(self, index): + if self.opt['data_type'] == 'lmdb': + if (self.GT_env is None) or (self.LR_env is None): + self._init_lmdb() + + GT_path, LR_path = None, None + scale = self.opt['scale'] + GT_size = self.opt['GT_size'] + + # get GT image + GT_path = self.GT_paths[index] + if self.opt['data_type'] == 'lmdb': + resolution = [int(s) for s in self.GT_sizes[index].split('_')] + else: + resolution = None + img_GT = util.read_img(self.GT_env, GT_path, resolution) # return: Numpy float32, HWC, BGR, [0,1] + + # modcrop in the validation / test phase + img_GT = util.modcrop(img_GT, scale) + + # get LR image + if self.LR_paths: # LR exist + raise ValueError('GTker_dataset.py doesn Not allow LR input.') + + else: # down-sampling on-the-fly + # randomly scale during training + if self.opt['phase'] == 'train': + random_scale = random.choice(self.random_scale_list) + if random_scale != 1: + H_s, W_s, _ = img_GT.shape + H_s = _mod(H_s, random_scale, scale, GT_size) + W_s = _mod(W_s, random_scale, scale, GT_size) + img_GT = cv2.resize(np.copy(img_GT), (W_s, H_s), interpolation=cv2.INTER_LINEAR) + + # force to 3 channels + if img_GT.ndim == 2: + img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) + + if self.opt['phase'] == 'train': + H, W, C = img_GT.shape + + # randomly crop on HR, more positions than first crop on LR and HR simultaneously + rnd_h_GT = random.randint(0, max(0, H - GT_size)) + rnd_w_GT = random.randint(0, max(0, W - GT_size)) + img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] + + # augmentation - flip, rotate + img_GT = util.augment(img_GT, self.opt['use_flip'], + self.opt['use_rot'], self.opt['mode']) + + # change color space if necessary, 
deal with gray image + if self.opt['color']: + img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_GT.shape[2] == 3: + img_GT = img_GT[:, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + + if LR_path is None: + LR_path = GT_path + + # don't need LR because it's generated from HR batches. + img_LR = torch.ones(1, 1, 1) + + return {'LQ': img_LR, 'GT': img_GT, 'LQ_path': LR_path, 'GT_path': GT_path} + + def __len__(self): + return len(self.GT_paths) + + +def _mod(n, random_scale, scale, thres): + rlt = int(n * random_scale) + rlt = (rlt // scale) * scale + return thres if rlt < thres else rlt diff --git a/codes/data/LQ_dataset.py b/codes/data/LQ_dataset.py new file mode 100644 index 0000000..d11ae9c --- /dev/null +++ b/codes/data/LQ_dataset.py @@ -0,0 +1,106 @@ +import random +import numpy as np +import cv2 +import lmdb +import torch +import torch.nn.functional as F +import torch.utils.data as data +import data.util as util +import sys +import os + +try: + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from data.util import imresize_np + from utils import util as utils +except ImportError: + pass + + +class LQDataset(data.Dataset): + ''' + Load LR images only, e.g. real-world images + ''' + + def __init__(self, opt): + super(LQDataset, self).__init__() + self.opt = opt + self.LR_paths, self.GT_paths = None, None + self.LR_env, self.GT_env = None, None # environment for lmdb + self.LR_size, self.GT_size = opt['LR_size'], opt['GT_size'] + + # read image list from lmdb or image files + if opt['data_type'] == 'lmdb': + self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) + self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) + elif opt['data_type'] == 'img': + self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list + self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list + else: + print('Error: data_type is not matched in Dataset') + assert self.LR_paths, 'Error: LQ paths are empty.' + if self.LR_paths and self.GT_paths: + assert len(self.LR_paths) == len( + self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( + len(self.LR_paths), len(self.GT_paths)) + self.random_scale_list = [1] + + def __getitem__(self, index): + + GT_path, LQ_path = None, None + + # get GT image + LQ_path = self.LR_paths[index] + img_LQ = util.read_img(None, LQ_path, None) # return: Numpy float32, HWC, BGR, [0,1] + + if self.GT_paths: # LR exist + raise ValueError('LQ_dataset.py doesn Not allow HR input.') + + else: + # force to 3 channels + if img_LQ.ndim == 2: + img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_GRAY2BGR) + + # change color space if necessary, deal with gray image + if self.opt['color']: + img_LQ = util.channel_convert(img_LQ.shape[2], self.opt['color'], [img_LQ])[0] + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_LQ.shape[2] == 3: + img_LQ = img_LQ[:, :, [2, 1, 0]] + img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + + if GT_path is None: + GT_path = LQ_path + + # don't need LR because it's generated from HR batches. 
+ img_GT = torch.ones(1, 1, 1) + + # deal with the image margins for real-world images + img_LQ = img_LQ.unsqueeze(0) + x_gt = F.interpolate(img_LQ, scale_factor=self.opt['scale'], mode='nearest') + if self.opt['scale'] == 4: + real_crop = 3 + elif self.opt['scale'] == 2: + real_crop = 6 + elif self.opt['scale'] == 1: + real_crop = 11 + assert real_crop * self.opt['scale'] * 2 > self.opt['kernel_size'] + x_gt = F.pad(x_gt, pad=( + real_crop * self.opt['scale'], real_crop * self.opt['scale'], real_crop * self.opt['scale'], + real_crop * self.opt['scale']), mode='replicate') # 'constant', 'reflect', 'replicate' or 'circular + + kernel_gt, sigma_gt = utils.stable_batch_kernel(1, l=self.opt['kernel_size'], sig=10, sig1=0, sig2=0, + theta=0, rate_iso=1, scale=self.opt['scale'], + tensor=True) # generate kernel [BHW], y [BCHW] + + blur_layer = utils.BatchBlur(l=self.opt['kernel_size'], padmode='zero') + sample_layer = utils.BatchSubsample(scale=self.opt['scale']) + y_blurred = sample_layer(blur_layer(x_gt, kernel_gt)) + y_blurred[:, :, real_crop:-real_crop, real_crop:-real_crop] = img_LQ + img_LQ = y_blurred.squeeze(0) + + return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} + + def __len__(self): + return len(self.LR_paths) diff --git a/codes/data/LRHR_PKL_dataset.py b/codes/data/LRHR_PKL_dataset.py new file mode 100644 index 0000000..5a25e8a --- /dev/null +++ b/codes/data/LRHR_PKL_dataset.py @@ -0,0 +1,193 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd. +# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode +# +# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE + +import os +# import subprocess +import torch.utils.data as data +import numpy as np +import time +import torch + +import pickle + + +class LRHR_PKLDataset(data.Dataset): + def __init__(self, opt): + super(LRHR_PKLDataset, self).__init__() + self.opt = opt + self.crop_size = opt.get("GT_size", None) + self.scale = None + self.random_scale_list = [1] + + hr_file_path = opt["dataroot_GT"] + lr_file_path = opt["dataroot_LQ"] + y_labels_file_path = opt['dataroot_y_labels'] + + gpu = True + augment = True + + self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False + self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False + self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False + self.center_crop_hr_size = opt.get("center_crop_hr_size", None) + + n_max = opt["n_max"] if "n_max" in opt.keys() else int(1e8) + + t = time.time() + self.lr_images = self.load_pkls(lr_file_path, n_max) + self.hr_images = self.load_pkls(hr_file_path, n_max) + + min_val_hr = np.min([i.min() for i in self.hr_images[:20]]) + max_val_hr = np.max([i.max() for i in self.hr_images[:20]]) + + min_val_lr = np.min([i.min() for i in self.lr_images[:20]]) + max_val_lr = np.max([i.max() for i in self.lr_images[:20]]) + + t = time.time() - t + print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}". + format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path)) + print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}". + format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path)) + + self.gpu = gpu + self.augment = augment + + self.measures = None + + # todo: this is very slow (~15min) when using nn.DistributedDataParallel(), and we have to set n_worker=0 + + # # save as png + # import cv2 + # for i in range(400): + # img = self.hr_images[i] + # img = np.transpose(img, [1,2,0]) + # cv2.imwrite("/cluster/work/cvl/jinliang/log/srflow_experiments/ICCV21/baseline_results/CelebA_HR_8X/{}.png".format(i), img[:,:,[2, 1, 0]]) + # img = self.lr_images[i] + # img = np.transpose(img, [1,2,0]) + # cv2.imwrite("/cluster/work/cvl/jinliang/log/srflow_experiments/ICCV21/baseline_results/CelebA_LR_8X/{}.png".format(i), img[:,:,[2, 1, 0]]) + # raise NotImplementedError + + def load_pkls(self, path, n_max): + assert os.path.isfile(path), path + images = [] + with open(path, "rb") as f: + images += pickle.load(f) + assert len(images) > 0, path + images = images[:n_max] + images = [np.transpose(image, [2, 0, 1]) for image in images] + return images + + def __len__(self): + return len(self.hr_images) + + def __getitem__(self, item): + + hr = self.hr_images[item] + lr = self.lr_images[item] + + if self.scale == None: + self.scale = hr.shape[1] // lr.shape[1] + assert hr.shape[1] == self.scale * lr.shape[1], ('non-fractional ratio', lr.shape, hr.shape) + + if self.use_crop: + hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop) + + if self.center_crop_hr_size: + hr, lr = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size // self.scale) + + if self.use_flip: + hr, lr = random_flip(hr, lr) + + if self.use_rot: + hr, lr = random_rotation(hr, lr) + + hr = hr / 255.0 + lr = lr / 255.0 + + if self.measures is None or np.random.random() < 0.05: + if self.measures is None: + self.measures = {} + self.measures['hr_means'] = np.mean(hr) + self.measures['hr_stds'] = np.std(hr) + self.measures['lr_means'] = np.mean(lr) + 
self.measures['lr_stds'] = np.std(lr) + + hr = torch.Tensor(hr) + lr = torch.Tensor(lr) + + # if self.gpu: + # hr = hr.cuda() + # lr = lr.cuda() + + return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)} + + def print_and_reset(self, tag): + m = self.measures + kvs = [] + for k in sorted(m.keys()): + kvs.append("{}={:.2f}".format(k, m[k])) + print("[KPI] " + tag + ": " + ", ".join(kvs)) + self.measures = None + + +def random_flip(img, seg): + random_choice = np.random.choice([True, False]) + img = img if random_choice else np.flip(img, 2).copy() + seg = seg if random_choice else np.flip(seg, 2).copy() + return img, seg + + +def random_rotation(img, seg): + random_choice = np.random.choice([0, 1, 3]) + img = np.rot90(img, random_choice, axes=(1, 2)).copy() + seg = np.rot90(seg, random_choice, axes=(1, 2)).copy() + return img, seg + + +def random_crop(hr, lr, size_hr, scale, random): + size_lr = size_hr // scale + + size_lr_x = lr.shape[1] + size_lr_y = lr.shape[2] + + start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0 + start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0 + + # LR Patch + lr_patch = lr[:, start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr] + + # HR Patch + start_x_hr = start_x_lr * scale + start_y_hr = start_y_lr * scale + hr_patch = hr[:, start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr] + + return hr_patch, lr_patch + + +def center_crop(img, size): + assert img.shape[1] == img.shape[2], img.shape + border_double = img.shape[1] - size + assert border_double % 2 == 0, (img.shape, size) + border = border_double // 2 + return img[:, border:-border, border:-border] + + +def center_crop_tensor(img, size): + assert img.shape[2] == img.shape[3], img.shape + border_double = img.shape[2] - size + assert border_double % 2 == 0, (img.shape, size) + border = border_double // 2 + return img[:, :, border:-border, border:-border] diff --git a/codes/data/__init__.py b/codes/data/__init__.py new file mode 100644 index 0000000..fd4ebb7 --- /dev/null +++ b/codes/data/__init__.py @@ -0,0 +1,54 @@ +'''create dataset and dataloader''' +import logging +import torch +import torch.utils.data + + +def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): + phase = dataset_opt['phase'] + if phase == 'train': + if opt['dist']: + world_size = torch.distributed.get_world_size() + num_workers = dataset_opt['n_workers'] + assert dataset_opt['batch_size'] % world_size == 0 + batch_size = dataset_opt['batch_size'] // world_size + shuffle = False + else: + num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) + batch_size = dataset_opt['batch_size'] + shuffle = True + return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, + num_workers=num_workers, sampler=sampler, drop_last=True, + pin_memory=True) + else: + return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, + pin_memory=True) + + +def create_dataset(dataset_opt): + mode = dataset_opt['mode'] + if mode == 'GT': # load HR image and generate LR on-the-fly + from data.GT_dataset import GTDataset as D + dataset = D(dataset_opt) + elif mode == 'GTLQ': # load generated HR-LR image pairs + from data.GTLQ_dataset import GTLQDataset as D + dataset = D(dataset_opt) + elif mode == 'GTLQx': # load generated HR-LR image pairs, and replace with x4 + from data.GTLQx_dataset import GTLQxDataset as D + dataset = D(dataset_opt) + elif mode == 'LQ': # 
load LR image for testing + from data.LQ_dataset import LQDataset as D + dataset = D(dataset_opt) + elif mode == 'LRHR_PKL': + from data.LRHR_PKL_dataset import LRHR_PKLDataset as D + dataset = D(dataset_opt) + elif mode == 'GTLQnpy': # load generated HR-LR image pairs + from data.GTLQnpy_dataset import GTLQnpyDataset as D + dataset = D(dataset_opt) + else: + raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) + + logger = logging.getLogger('base') + logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, + dataset_opt['name'])) + return dataset diff --git a/codes/data/data_sampler.py b/codes/data/data_sampler.py new file mode 100644 index 0000000..ba9a72b --- /dev/null +++ b/codes/data/data_sampler.py @@ -0,0 +1,109 @@ +""" +Modified from torch.utils.data.distributed.DistributedSampler +Support enlarging the dataset for *iter-oriented* training, for saving time when restart the +dataloader after each epoch +""" +import math +import torch +from torch.utils.data.sampler import Sampler +import torch.distributed as dist + + +class DistIterSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() #Returns a random permutation of integers from 0 to n - 1 + + dsize = len(self.dataset) + indices = [v % dsize for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. 
+ """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil( + len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch \ No newline at end of file diff --git a/codes/data/util.py b/codes/data/util.py new file mode 100644 index 0000000..84666f0 --- /dev/null +++ b/codes/data/util.py @@ -0,0 +1,536 @@ +import os +import math +import pickle +import random +import numpy as np +import torch +import cv2 + +#################### +# Files & IO +#################### + +###################### get image path list ###################### +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.npy'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def _get_paths_from_images(path): + '''get image path list from image folder''' + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +def _get_paths_from_lmdb(dataroot): + '''get image path list from lmdb meta info''' + meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) + paths = meta_info['keys'] + sizes = meta_info['resolution'] + if len(sizes) == 1: + sizes = sizes * len(paths) + return paths, sizes + + +def get_image_paths(data_type, dataroot): + '''get image path list + support lmdb or image files''' + paths, sizes = None, None + if data_type == 'lmdb': + if dataroot is not None: + paths, sizes = _get_paths_from_lmdb(dataroot) + return paths, sizes + elif data_type == 'img': + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + else: + raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) + + +###################### read images ###################### +def _read_img_lmdb(env, key, size): + '''read image from lmdb with key (w/ and w/o fixed size) + size: (C, H, W) tuple''' + with env.begin(write=False) as txn: + buf = txn.get(key.encode('ascii')) + img_flat = np.frombuffer(buf, dtype=np.uint8) + C, H, W = size + img = img_flat.reshape(H, W, C) + return img + + +def read_img(env, path, size=None): + '''read image by cv2 or from lmdb + return: Numpy float32, HWC, BGR, [0,1]''' + if env is None: # img + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + else: + img = _read_img_lmdb(env, path, size) + img = img.astype(np.float32) / 255. 
+ if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) # added by jinliang + # img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + +def read_img_fromcv2(img): + '''transform image in cv2 to the same output as read_img + return: Numpy float32, HWC, BGR, [0,1]''' + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + +def read_img_fromnpy(img): + '''read image in sio to the same output as read_img + return: Numpy float32, HWC, BGR, [0,1]''' + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) # added by jinliang + # img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img +#################### +# image processing +# process on numpy image +#################### + + +def augment(img, hflip=True, rot=True, mode=None): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + if mode == 'LQ' or mode == 'GT' or mode == 'SRker' or mode == 'GTker' or mode == 'GTker_memory' or mode == 'GTLQnpy_noisy': + return _augment(img) + elif mode == 'GTLQ' or mode == 'GTLQnpy': + return [_augment(I) for I in img] + else: + raise NotImplementedError('{} dataloader has not been implemented yet.'.format(mode)) + + +def augment_flow(img_list, flow_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: + flow = flow[:, ::-1, :] + flow[:, :, 0] *= -1 + if vflip: + flow = flow[::-1, :, :] + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + rlt_img_list = [_augment(img) for img in img_list] + rlt_flow_list = [_augment_flow(flow) for flow in flow_list] + + return rlt_img_list, rlt_flow_list + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. 
+ # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +#################### +# Functions +#################### + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (( + (absx > 1) * (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. 
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: CHW RGB [0,1] + # output: CHW RGB [0,1] w/o round + + in_C, in_H, in_W = img.size() + _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = 
torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i]) + + return out_2 + + +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC BGR [0,1] + # output: HWC BGR [0,1] w/o round + img = torch.from_numpy(img) + + in_H, in_W, in_C = img.size() + _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i]) + out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i]) + out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i]) + + return out_2.numpy() + + + ### Load data kernel map ### +def load_ker_map_list(path): + real_ker_map_list = [] + batch_kermap = torch.load(path) + size_kermap = batch_kermap.size() + m = size_kermap[0] + for i in range(m): + real_ker_map_list.append(batch_kermap[i]) + + return real_ker_map_list + + +def 
test_patchwise(model, L, patchsize=48, overlapsize=16, sf=1): + ''' + Args: + model: deep model + L: input low-quality image, bxcxhxw + patchsize: + overlapsize: + sf: scale factor, 1 for denoising + Return: + E: output estimated image, bxcx(h*sf)x(w*sf) + + see another version for segmentation here: https://github.com/fudan-zvg/SETR/blob/main/mmseg/models/segmentors/encoder_decoder.py#L169 + ''' + b, c, h, w = L.size() + stride = patchsize-overlapsize + h_idx_list = list(range(0, h-patchsize, stride)) + [h-patchsize] + w_idx_list = list(range(0, w-patchsize, stride)) + [w-patchsize] + E = torch.zeros(b, c, h*sf, w*sf).type_as(L) + W = torch.zeros_like(E) + for h_idx in h_idx_list: + for w_idx in w_idx_list: + in_patch = L[..., h_idx:h_idx+patchsize, w_idx:w_idx+patchsize] + out_patch = model(in_patch) + E[..., h_idx*sf:(h_idx+patchsize)*sf, w_idx*sf:(w_idx+patchsize)*sf].add_(out_patch) + W[..., h_idx*sf:(h_idx+patchsize)*sf, w_idx*sf:(w_idx+patchsize)*sf].add_(1) + return E.div_(W) + + +if __name__ == '__main__': + # test imresize function + # read images + img = cv2.imread('test.png') + img = img * 1.0 / 255 + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + # imresize + scale = 1 / 4 + import time + total_time = 0 + for i in range(10): + start_time = time.time() + rlt = imresize(img, scale, antialiasing=True) + use_time = time.time() - start_time + total_time += use_time + print('average time: {}'.format(total_time / 10)) + + import torchvision.utils + torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0, + normalize=False) diff --git a/codes/models/HCFlow_Rescaling_model.py b/codes/models/HCFlow_Rescaling_model.py new file mode 100644 index 0000000..20d3c04 --- /dev/null +++ b/codes/models/HCFlow_Rescaling_model.py @@ -0,0 +1,521 @@ +# base model for SRFlow, a flow-based model +import logging +from collections import OrderedDict +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from models.modules.loss import GANLoss +from .base_model import BaseModel +import utils.util as util +from models.modules.loss import ReconstructionLoss +from models.modules.Basic import Quantization +# from torch.cuda.amp import autocast as autocast + +logger = logging.getLogger('base') + + +class HCFLowRescalingModel(BaseModel): + def __init__(self, opt, step): + super(HCFLowRescalingModel, self).__init__(opt) + self.opt = opt + + self.hr_size = util.opt_get(opt, ['datasets', 'train', 'GT_size'], 160) + self.lr_size = self.hr_size // opt['scale'] + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + + # define network and load pretrained models + self.netG = networks.define_G(opt, step).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + + self.Quantization = Quantization() + + if self.is_train: + train_opt = opt['train'] + self.netG.train() + + # NLL weight + # self.cri_pix_hr = ReconstructionLoss(losstype=train_opt['pixel_criterion_hr']) + # self.cri_pix_lr = ReconstructionLoss(losstype=train_opt['pixel_criterion_lr']) + self.eps_std_reverse = train_opt['eps_std_reverse'] + + if train_opt['pixel_weight_hr'] > 0: + loss_type = train_opt['pixel_criterion_hr'] + if loss_type == 
'l1': + self.cri_pix_hr = nn.L1Loss().to(self.device) + elif loss_type == 'l2': + self.cri_pix_hr = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) + self.l_pix_w_hr = train_opt['pixel_weight_hr'] + else: + logger.info('Remove HR pixel loss.') + self.cri_pix_hr = None + + if train_opt['pixel_weight_lr'] > 0: + loss_type = train_opt['pixel_criterion_lr'] + if loss_type == 'l1': + self.cri_pix_lr = nn.L1Loss().to(self.device) + elif loss_type == 'l2': + self.cri_pix_lr = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) + self.l_pix_w_lr = train_opt['pixel_weight_lr'] + else: + logger.info('Remove LR pixel loss.') + self.cri_pix_lr = None + + self.l_pix_w_hr = train_opt['pixel_weight_hr'] + self.l_pix_w_lr = train_opt['pixel_weight_lr'] + self.l_w_z = train_opt['weight_z'] + + # HR feature loss + if train_opt['feature_weight'] > 0: + l_fea_type = train_opt['feature_criterion'] + if l_fea_type == 'l1': + self.cri_fea = nn.L1Loss().to(self.device) + elif l_fea_type == 'l2': + self.cri_fea = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) + self.l_fea_w = train_opt['feature_weight'] + + # load VGG perceptual loss + self.netF = networks.define_F(opt, use_bn=False).to(self.device) + if opt['dist']: + self.netF = DistributedDataParallel(self.netF, device_ids=[torch.cuda.current_device()]) + else: + self.netF = DataParallel(self.netF) + else: + logger.info('Remove feature loss.') + self.cri_fea = None + + # HR GAN loss + # put here to be compatible with PSNR version + self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 + self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 + if train_opt['gan_weight'] > 0: + self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) + self.l_gan_w = train_opt['gan_weight'] + + # define GAN Discriminator + self.netD = networks.define_D(opt).to(self.device) + if opt['dist']: + self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()]) + else: + self.netD = DataParallel(self.netD) + self.netD.train() + else: + logger.info('Remove GAN loss.') + self.cri_gan = None + + # gradient clip & norm + self.max_grad_clip = util.opt_get(train_opt, ['max_grad_clip']) + self.max_grad_norm = util.opt_get(train_opt, ['max_grad_norm']) + + # optimizers + # G + wd_G = util.opt_get(train_opt, ['weight_decay_G'], 0) + optim_params = [] + for k, v in self.netG.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + # if v.requires_grad and ('additional_flow_steps' in k or 'additional_feature_steps' in k): # fixmainflow + # if v.requires_grad and ('additional_flow_steps' in k): # fix mainflowRRDB + optim_params.append(v) + else: + v.requires_grad = False + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1'], train_opt['beta2'])) + self.optimizers.append(self.optimizer_G) + + # D + if self.cri_gan: + wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], + weight_decay=wd_D, + betas=(train_opt['beta1_D'], train_opt['beta2_D'])) + self.optimizers.append(self.optimizer_D) + + # schedulers + if 
train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + print('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + # val + if 'val' in opt: + self.heats = opt['val']['heats'] + self.n_sample = opt['val']['n_sample'] + self.sr_mode = opt['val']['sr_mode'] + + self.print_network() # print network + self.load() # load G and D if needed + + def init_model(self, scale=0.1): + # Common practise for initialization. + for layer in self.netG.modules(): + if isinstance(layer, nn.Conv2d): + init.kaiming_normal_(layer.weight, a=0, mode='fan_in') + layer.weight.data *= scale # for residual block + if layer.bias is not None: + layer.bias.data.zero_() + elif isinstance(layer, nn.Linear): + init.kaiming_normal_(layer.weight, a=0, mode='fan_in') + layer.weight.data *= scale + if layer.bias is not None: + layer.bias.data.zero_() + elif isinstance(layer, nn.BatchNorm2d): + init.constant_(layer.weight, 1) + init.constant_(layer.bias.data, 0.0) + + def feed_data(self, data, need_GT=True): + self.var_L = data['LQ'].to(self.device) # LQ + if need_GT: + self.real_H = data['GT'].to(self.device) # GT + + def optimize_parameters(self, step): + # special initialization for actnorm; don't initialize when fine-tuned with gan + if step < self.opt['network_G']['act_norm_start_step'] and not (self.cri_pix_hr or self.cri_gan): + self.set_actnorm_init(inited=False) + + # (1) G + fake_H = None + if (step % self.D_update_ratio == 0 and step > self.D_init_iters) or (not self.cri_gan): + self.optimizer_G.zero_grad() + + fake_LR, fake_z1, fake_z2 = self.netG(hr=self.real_H, lr=self.var_L, u=None, reverse=False) + l_g_lr = self.l_pix_w_lr * self.cri_pix_lr(fake_LR, self.var_L) + l_g_z = self.l_w_z * (torch.cat([fake_z1.flatten(), fake_z2.flatten()],0)**2).mean() + + fake_LR = self.Quantization(fake_LR) + fake_H = self.netG(lr=fake_LR, z=None, u=None, eps_std=self.eps_std_reverse, reverse=True) + l_g_hr = self.l_pix_w_hr * self.cri_pix_hr(fake_H, self.real_H) + + l_g_total = 0 + if torch.isfinite(l_g_lr): + l_g_total += l_g_lr + if torch.isfinite(l_g_z): + l_g_total += l_g_z + if torch.isfinite(l_g_hr): + l_g_total += l_g_hr + + self.log_dict['l_g_lr'] = l_g_lr.item() + self.log_dict['l_g_z'] = l_g_z.item() + self.log_dict['l_g_hr'] = l_g_hr.item() + + + ######################## + # feature loss + if self.cri_fea: + real_fea = self.netF(self.real_H).detach() + fake_fea = self.netF(fake_H) + l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) + if torch.isfinite(l_g_fea): + l_g_total += l_g_fea + self.log_dict['l_g_fea'] = l_g_fea.item() + + # gan loss + if self.cri_gan: + for p in self.netD.parameters(): + p.requires_grad = False + + pred_g_fake = self.netD(fake_H) + if self.opt['train']['gan_type'] in ['gan', 'lsgan', 'wgangp']: + l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) + elif self.opt['train']['gan_type'] == 'ragan': + pred_d_real = self.netD(self.real_H).detach() + l_g_gan = self.l_gan_w * ( 
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + if torch.isfinite(l_g_gan): + l_g_total += l_g_gan + self.log_dict['l_g_gan'] = l_g_gan.item() + + if l_g_total != 0: + l_g_total.backward() + self.gradient_clip() + self.optimizer_G.step() + + # (2) D + if self.cri_gan: + self.optimizer_G.zero_grad() # can help save memory + + for p in self.netD.parameters(): + p.requires_grad = True + + # initialize D + if fake_H is None: + with torch.no_grad(): + fake_H = self.netG(lr=self.var_L, z=None, u=None, eps_std=self.eps_std_reverse, reverse=True) + + self.optimizer_D.zero_grad() + pred_d_real = self.netD(self.real_H) + pred_d_fake = self.netD(fake_H.detach()) # detach to avoid BP to G + if self.opt['train']['gan_type'] in ['gan', 'lsgan', 'wgangp']: + l_d_real = self.cri_gan(pred_d_real, True) + l_d_fake = self.cri_gan(pred_d_fake, False) + l_d_total = l_d_real + l_d_fake + elif self.opt['train']['gan_type'] == 'ragan': + l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) + l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) + l_d_total = (l_d_real + l_d_fake) / 2 + + self.log_dict['l_d_real'] = l_d_real.item() + self.log_dict['l_d_fake'] = l_d_fake.item() + self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) + self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) + + if torch.isfinite(l_d_total): + l_d_total.backward() + self.optimizer_D.step() + + def gradient_clip(self): + # gradient clip & norm, is not used in SRFlow + if self.max_grad_clip is not None: + torch.nn.utils.clip_grad_value_(self.netG.parameters(), self.max_grad_clip) + if self.max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(self.netG.parameters(), self.max_grad_norm) + + def test(self): + self.netG.eval() + self.fake_H = {} + + with torch.no_grad(): + # hr->lr+z, calculate nll + self.fake_L_from_H, fake_z1, fake_z2 = self.netG(hr=self.real_H, lr=self.var_L, u=None, reverse=False, training=False) + self.fake_L_from_H = self.Quantization(self.fake_L_from_H) + + # lr+z->hr + for heat in self.heats: + for sample in range(self.n_sample): + # z = self.get_z(heat, seed=1, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) + self.fake_H[(heat, sample)] = self.netG(lr=self.fake_L_from_H, + z=None, u=None, eps_std=heat, reverse=True, training=False) + + self.netG.train() + + return fake_z1.mean().item() + + def get_encode_nll(self, lq, hr, y_label=None): + self.netG.eval() + with torch.no_grad(): + _, nll, _ = self.netG(hr=hr, lr=lq, reverse=False, y_label=y_label) + self.netG.train() + return nll.mean().item() + + def get_sr(self, lq, heat=None, seed=None, z=None, epses=None, y_label=None): + return self.get_sr_with_z(lq, heat, seed, z, epses, y_label=y_label)[0] + + def get_encode_z(self, lq, hr, epses=None, add_gt_noise=True, y_label=None): + self.netG.eval() + with torch.no_grad(): + z, _, _ = self.netG(hr=hr, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise, y_label=y_label) + self.netG.train() + return z + + def get_encode_z_and_nll(self, lq, hr, epses=None, y_label=None): + self.netG.eval() + with torch.no_grad(): + z, nll, _ = self.netG(hr=hr, lr=lq, reverse=False, epses=epses, y_label=y_label) + self.netG.train() + return z, nll + + def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None, y_label=None): + self.netG.eval() + + z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape, + y_label=None) if z is None and epses is None else z + + with 
torch.no_grad(): + sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses, y_label=None) + self.netG.train() + return sr, z + + def get_z(self, heat, seed=None, batch_size=1, lr_shape=None, y_label=None): + if y_label is None: + pass + if seed: torch.manual_seed(seed) + if util.opt_get(self.opt, ['network_G', 'flowDownsampler', 'splitOff', 'enable']): + C, H, W = lr_shape[1], lr_shape[2], lr_shape[3] + + size = (batch_size, C, H, W) + if heat == 0: + z = torch.zeros(size) + else: + z = torch.normal(mean=0, std=heat, size=size) + else: + L = util.opt_get(self.opt, ['network_G', 'flow', 'L']) or 3 + fac = 2 ** (L - 3) + z_size = int(self.lr_size // (2 ** (L - 3))) + z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)) + return z.to(self.device) + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self, need_GT=True): + out_dict = OrderedDict() + out_dict['LQ'] = self.var_L.detach()[0].float().cpu() + for heat in self.heats: + for i in range(self.n_sample): + out_dict[('SR', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu() + + if need_GT: + out_dict['GT'] = self.real_H.detach()[0].float().cpu() + out_dict['LQ_fromH'] = self.fake_L_from_H.detach()[0].float().cpu() + + return out_dict + + def print_network(self): + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + if self.is_train: + # Discriminator + if self.cri_gan: + s, n = self.get_network_description(self.netD) + if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD, + DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, + self.netD.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netD.__class__.__name__) + if self.rank <= 0: + logger.info('Network D structure: {}, with parameters: {:,d}'.format( + net_struc_str, n)) + logger.info(s) + + # F, Perceptual Network + if self.cri_fea: + s, n = self.get_network_description(self.netF) + if isinstance(self.netF, nn.DataParallel) or isinstance( + self.netF, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, + self.netF.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netF.__class__.__name__) + if self.rank <= 0: + logger.info('Network F structure: {}, with parameters: {:,d}'.format( + net_struc_str, n)) + logger.info(s) + + def load(self): + # resume training automatically if resume_state=='auto' + _, get_resume_model_path = util.get_resume_paths(self.opt) + if get_resume_model_path is not None: + logger.info('Automatically loading model for G [{:s}] ...'.format(get_resume_model_path)) + self.load_network(get_resume_model_path, self.netG, strict=True, submodule=None) + self.set_actnorm_init(inited=True) + + if self.is_train and self.cri_gan: + get_resume_model_path = get_resume_model_path.replace('_G.pth', '_D.pth') + logger.info('Automatically loading model for D [{:s}] ...'.format(get_resume_model_path)) + self.load_network(get_resume_model_path, self.netD, strict=True, submodule=None) + return + + # resume training according to given paths (pretrain path has been overrided by 
resume path) + if self.opt.get('path') is not None: + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path'].get('strict_load', True)) + self.set_actnorm_init(inited=True) + + if self.is_train and self.cri_gan: + load_path_D = self.opt['path']['pretrain_model_D'] + if load_path_D is not None: + logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) + self.load_network(load_path_D, self.netD, self.opt['path'].get('strict_load', True)) + + def save(self, iter_label): + self.save_network(self.netG, 'G', iter_label) + if self.cri_gan: + self.save_network(self.netD, 'D', iter_label) + + def set_actnorm_init(self, inited=True): + for name, m in self.netG.named_modules(): + if (m.__class__.__name__.find("ActNorm") >= 0): + m.inited = inited + + # from glow + def generate_z(self, img): + self.eval() + B = self.hparams.Train.batch_size + x = img.unsqueeze(0).repeat(B, 1, 1, 1).cuda() + z,_, _ = self(x) + self.train() + return z[0].detach().cpu().numpy() + + def generate_attr_deltaz(self, dataset): + assert "y_onehot" in dataset[0] + self.eval() + with torch.no_grad(): + B = self.hparams.Train.batch_size + N = len(dataset) + attrs_pos_z = [[0, 0] for _ in range(self.y_classes)] + attrs_neg_z = [[0, 0] for _ in range(self.y_classes)] + for i in tqdm(range(0, N, B)): + j = min([i + B, N]) + # generate z for data from i to j + xs = [dataset[k]["x"] for k in range(i, j)] + while len(xs) < B: + xs.append(dataset[0]["x"]) + xs = torch.stack(xs).cuda() + zs, _, _ = self(xs) + for k in range(i, j): + z = zs[k - i].detach().cpu().numpy() + # append to different attrs + y = dataset[k]["y_onehot"] + for ai in range(self.y_classes): + if y[ai] > 0: + attrs_pos_z[ai][0] += z + attrs_pos_z[ai][1] += 1 + else: + attrs_neg_z[ai][0] += z + attrs_neg_z[ai][1] += 1 + # break + deltaz = [] + for ai in range(self.y_classes): + if attrs_pos_z[ai][1] == 0: + attrs_pos_z[ai][1] = 1 + if attrs_neg_z[ai][1] == 0: + attrs_neg_z[ai][1] = 1 + z_pos = attrs_pos_z[ai][0] / float(attrs_pos_z[ai][1]) + z_neg = attrs_neg_z[ai][0] / float(attrs_neg_z[ai][1]) + deltaz.append(z_pos - z_neg) + self.train() + return deltaz \ No newline at end of file diff --git a/codes/models/HCFlow_SR_model.py b/codes/models/HCFlow_SR_model.py new file mode 100644 index 0000000..7ed42e4 --- /dev/null +++ b/codes/models/HCFlow_SR_model.py @@ -0,0 +1,510 @@ +# base model for HCFlow +import logging +from collections import OrderedDict +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from models.modules.loss import GANLoss +from .base_model import BaseModel +import utils.util as util + +logger = logging.getLogger('base') + + +class HCFlowSRModel(BaseModel): + def __init__(self, opt, step): + super(HCFlowSRModel, self).__init__(opt) + self.opt = opt + + self.hr_size = util.opt_get(opt, ['datasets', 'train', 'GT_size'], 160) + self.lr_size = self.hr_size // opt['scale'] + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + + # define network and load pretrained models + self.netG = networks.define_G(opt, step).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + 
self.netG = DataParallel(self.netG) + + if self.is_train: + train_opt = opt['train'] + self.netG.train() + + # NLL weight + self.l_nll_w = train_opt['nll_weight'] + self.eps_std_reverse = train_opt['eps_std_reverse'] + + # HR pixel loss + if train_opt['pixel_weight_hr'] > 0: + loss_type = train_opt['pixel_criterion_hr'] + if loss_type == 'l1': + self.cri_pix_hr = nn.L1Loss().to(self.device) + elif loss_type == 'l2': + self.cri_pix_hr = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) + self.l_pix_w_hr = train_opt['pixel_weight_hr'] + else: + logger.info('Remove HR pixel loss.') + self.cri_pix_hr = None + + # HR feature loss + if train_opt['feature_weight'] > 0: + l_fea_type = train_opt['feature_criterion'] + if l_fea_type == 'l1': + self.cri_fea = nn.L1Loss().to(self.device) + elif l_fea_type == 'l2': + self.cri_fea = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) + self.l_fea_w = train_opt['feature_weight'] + + # load VGG perceptual loss + self.netF = networks.define_F(opt, use_bn=False).to(self.device) + if opt['dist']: + self.netF = DistributedDataParallel(self.netF, device_ids=[torch.cuda.current_device()]) + else: + self.netF = DataParallel(self.netF) + else: + logger.info('Remove feature loss.') + self.cri_fea = None + + # HR GAN loss + # put here to be compatible with PSNR version + self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 + self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 + if train_opt['gan_weight'] > 0: + self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) + self.l_gan_w = train_opt['gan_weight'] + + # define GAN Discriminator + self.netD = networks.define_D(opt).to(self.device) + if opt['dist']: + self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()]) + else: + self.netD = DataParallel(self.netD) + self.netD.train() + else: + logger.info('Remove GAN loss.') + self.cri_gan = None + + # gradient clip & norm + self.max_grad_clip = util.opt_get(train_opt, ['max_grad_clip']) + self.max_grad_norm = util.opt_get(train_opt, ['max_grad_norm']) + + # optimizers + # G + wd_G = util.opt_get(train_opt, ['weight_decay_G'], 0) + optim_params = [] + for k, v in self.netG.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + # if v.requires_grad and ('additional_flow_steps' in k or 'additional_feature_steps' in k): # fixmainflow + # if v.requires_grad and ('additional_flow_steps' in k): # fix mainflowRRDB + optim_params.append(v) + else: + v.requires_grad = False + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1'], train_opt['beta2'])) + self.optimizers.append(self.optimizer_G) + + # D + if self.cri_gan: + wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], + weight_decay=wd_D, + betas=(train_opt['beta1_D'], train_opt['beta2_D'])) + self.optimizers.append(self.optimizer_D) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], 
+ gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + print('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + # val + if 'val' in opt: + self.heats = opt['val']['heats'] + self.n_sample = opt['val']['n_sample'] + self.sr_mode = opt['val']['sr_mode'] + + self.print_network() # print network + self.load() # load G and D if needed + + def init_model(self, scale=0.1): + # Common practise for initialization. + for layer in self.netG.modules(): + if isinstance(layer, nn.Conv2d): + init.kaiming_normal_(layer.weight, a=0, mode='fan_in') + layer.weight.data *= scale # for residual block + if layer.bias is not None: + layer.bias.data.zero_() + elif isinstance(layer, nn.Linear): + init.kaiming_normal_(layer.weight, a=0, mode='fan_in') + layer.weight.data *= scale + if layer.bias is not None: + layer.bias.data.zero_() + elif isinstance(layer, nn.BatchNorm2d): + init.constant_(layer.weight, 1) + init.constant_(layer.bias.data, 0.0) + + def feed_data(self, data, need_GT=True): + self.var_L = data['LQ'].to(self.device) # LQ + if need_GT: + self.real_H = data['GT'].to(self.device) # GT + + def optimize_parameters(self, step): + # special initialization for actnorm; don't initialize when fine-tuning + if step < self.opt['network_G']['act_norm_start_step'] and not (self.cri_pix_hr or self.cri_gan): + self.set_actnorm_init(inited=False) + + # (1) G + fake_H = None + if (step % self.D_update_ratio == 0 and step > self.D_init_iters) or (not self.cri_gan): + # normal flow + l_g_total = 0 + + _, nll = self.netG(hr=self.real_H, lr=self.var_L, u=None, reverse=False) + nll = self.l_nll_w * nll.sum() + self.log_dict['nll'] = nll.item() + if not torch.isnan(nll).any(): + l_g_total += nll + if l_g_total != 0: + l_g_total.backward() + self.gradient_clip() + self.optimizer_G.step() + + # reverse flow (optimize NLL loss and HR loss seperately (takes less memory and more time, slightly better results)) + self.optimizer_G.zero_grad() + if self.cri_pix_hr: + fake_H = self.netG(lr=self.var_L, z=None, u=None, eps_std=0.0, reverse=True) + l_g_total = 0 + if not torch.isnan(fake_H).any(): + # pixel loss + l_g_pix_hr = self.l_pix_w_hr * self.cri_pix_hr(fake_H, self.real_H) + l_g_total += l_g_pix_hr + self.log_dict['l_g_pix_hr'] = l_g_pix_hr.item() + if l_g_total != 0: + l_g_total.backward() + self.gradient_clip() + self.optimizer_G.step() + + + ######################## + + if self.cri_gan or self.cri_fea: + self.optimizer_G.zero_grad() + fake_H = self.netG(lr=self.var_L, z=None, u=None, eps_std=self.eps_std_reverse, reverse=True) + l_g_fea_gan = 0 + + # feature loss + if self.cri_fea: + real_fea = self.netF(self.real_H).detach() + fake_fea = self.netF(fake_H) + l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) + l_g_fea_gan += l_g_fea + self.log_dict['l_g_fea'] = l_g_fea.item() + + # gan loss + if self.cri_gan: + for p in self.netD.parameters(): + p.requires_grad = False + + pred_g_fake = self.netD(fake_H) + if self.opt['train']['gan_type'] in ['gan', 'lsgan', 'wgangp']: + l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) + elif self.opt['train']['gan_type'] == 'ragan': + pred_d_real = self.netD(self.real_H).detach() + l_g_gan = self.l_gan_w * ( 
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + l_g_fea_gan += l_g_gan + self.log_dict['l_g_gan'] = l_g_gan.item() + + if not torch.isnan(l_g_fea_gan): + l_g_fea_gan.backward() + self.gradient_clip() + self.optimizer_G.step() + + # (2) D + if self.cri_gan: + self.optimizer_G.zero_grad() # can help save memory + + for p in self.netD.parameters(): + p.requires_grad = True + + # initialize D + if fake_H is None: + with torch.no_grad(): + fake_H = self.netG(lr=self.var_L, z=None, u=None, eps_std=self.eps_std_reverse, reverse=True) + + self.optimizer_D.zero_grad() + pred_d_real = self.netD(self.real_H) + pred_d_fake = self.netD(fake_H.detach()) # detach to avoid BP to G + if self.opt['train']['gan_type'] in ['gan', 'lsgan', 'wgangp']: + l_d_real = self.cri_gan(pred_d_real, True) + l_d_fake = self.cri_gan(pred_d_fake, False) + l_d_total = l_d_real + l_d_fake + elif self.opt['train']['gan_type'] == 'ragan': + l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) + l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) + l_d_total = (l_d_real + l_d_fake) / 2 + + self.log_dict['l_d_real'] = l_d_real.item() + self.log_dict['l_d_fake'] = l_d_fake.item() + self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) + self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) + + if not torch.isnan(l_d_total): + l_d_total.backward() + self.optimizer_D.step() + + def gradient_clip(self): + # gradient clip & norm, is not used in SRFlow + if self.max_grad_clip is not None: + torch.nn.utils.clip_grad_value_(self.netG.parameters(), self.max_grad_clip) + if self.max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(self.netG.parameters(), self.max_grad_norm) + + def test(self): + self.netG.eval() + self.fake_H = {} + + with torch.no_grad(): + # hr->lr+z, calculate nll + self.fake_L_from_H, nll = self.netG(hr=self.real_H, lr=self.var_L, u=None, reverse=False, training=False) + # nll = torch.zeros(1) + + # lr+z->hr + for heat in self.heats: + for sample in range(self.n_sample): + # z = self.get_z(heat, seed=1, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) + self.fake_H[(heat, sample)] = self.netG(lr=self.var_L, + z=None, u=None, eps_std=heat, reverse=True, training=False) + + self.netG.train() + + return nll.mean().item() + + def get_encode_nll(self, lq, hr, y_label=None): + self.netG.eval() + with torch.no_grad(): + _, nll, _ = self.netG(hr=hr, lr=lq, reverse=False, y_label=y_label) + self.netG.train() + return nll.mean().item() + + def get_sr(self, lq, heat=None, seed=None, z=None, epses=None, y_label=None): + return self.get_sr_with_z(lq, heat, seed, z, epses, y_label=y_label)[0] + + def get_encode_z(self, lq, hr, epses=None, add_gt_noise=True, y_label=None): + self.netG.eval() + with torch.no_grad(): + z, _, _ = self.netG(hr=hr, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise, y_label=y_label) + self.netG.train() + return z + + def get_encode_z_and_nll(self, lq, hr, epses=None, y_label=None): + self.netG.eval() + with torch.no_grad(): + z, nll, _ = self.netG(hr=hr, lr=lq, reverse=False, epses=epses, y_label=y_label) + self.netG.train() + return z, nll + + def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None, y_label=None): + self.netG.eval() + + z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape, + y_label=None) if z is None and epses is None else z + + with torch.no_grad(): + sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, 
reverse=True, epses=epses, y_label=None) + self.netG.train() + return sr, z + + def get_z(self, heat, seed=None, batch_size=1, lr_shape=None, y_label=None): + if y_label is None: + pass + if seed: torch.manual_seed(seed) + if util.opt_get(self.opt, ['network_G', 'flowLR', 'splitOff', 'enable']): + C, H, W = lr_shape[1], lr_shape[2], lr_shape[3] + + size = (batch_size, C, H, W) + if heat == 0: + z = torch.zeros(size) + else: + z = torch.normal(mean=0, std=heat, size=size) + else: + L = util.opt_get(self.opt, ['network_G', 'flow', 'L']) or 3 + fac = 2 ** (L - 3) + z_size = int(self.lr_size // (2 ** (L - 3))) + z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)) + return z.to(self.device) + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self, need_GT=True): + out_dict = OrderedDict() + out_dict['LQ'] = self.var_L.detach()[0].float().cpu() + for heat in self.heats: + for i in range(self.n_sample): + out_dict[('SR', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu() + + if need_GT: + out_dict['GT'] = self.real_H.detach()[0].float().cpu() + out_dict['LQ_fromH'] = self.fake_L_from_H.detach()[0].float().cpu() + + return out_dict + + def print_network(self): + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + if self.is_train: + # Discriminator + if self.cri_gan: + s, n = self.get_network_description(self.netD) + if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD, + DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, + self.netD.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netD.__class__.__name__) + if self.rank <= 0: + logger.info('Network D structure: {}, with parameters: {:,d}'.format( + net_struc_str, n)) + logger.info(s) + + # F, Perceptual Network + if self.cri_fea: + s, n = self.get_network_description(self.netF) + if isinstance(self.netF, nn.DataParallel) or isinstance( + self.netF, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, + self.netF.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netF.__class__.__name__) + if self.rank <= 0: + logger.info('Network F structure: {}, with parameters: {:,d}'.format( + net_struc_str, n)) + logger.info(s) + + def load(self): + # resume training automatically if resume_state=='auto' + _, get_resume_model_path = util.get_resume_paths(self.opt) + if get_resume_model_path is not None: + logger.info('Automatically loading model for G [{:s}] ...'.format(get_resume_model_path)) + self.load_network(get_resume_model_path, self.netG, strict=True, submodule=None) + self.set_actnorm_init(inited=True) + + if self.is_train and self.cri_gan: + get_resume_model_path = get_resume_model_path.replace('_G.pth', '_D.pth') + logger.info('Automatically loading model for D [{:s}] ...'.format(get_resume_model_path)) + self.load_network(get_resume_model_path, self.netD, strict=True, submodule=None) + return + + # resume training according to given paths (pretrain path has been overrided by resume path) + if self.opt.get('path') is not None: + load_path_G = 
self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path'].get('strict_load', True)) + + self.set_actnorm_init(inited=True) + + if self.is_train and self.cri_gan: + load_path_D = self.opt['path']['pretrain_model_D'] + if load_path_D is not None: + logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) + self.load_network(load_path_D, self.netD, self.opt['path'].get('strict_load', True)) + + def save(self, iter_label): + self.save_network(self.netG, 'G', iter_label) + if self.cri_gan: + self.save_network(self.netD, 'D', iter_label) + + def set_actnorm_init(self, inited=True): + for name, m in self.netG.named_modules(): + if (m.__class__.__name__.find("ActNorm") >= 0): + m.inited = inited + + # from glow + def generate_z(self, img): + self.eval() + B = self.hparams.Train.batch_size + x = img.unsqueeze(0).repeat(B, 1, 1, 1).cuda() + z,_, _ = self(x) + self.train() + return z[0].detach().cpu().numpy() + + def generate_attr_deltaz(self, dataset): + assert "y_onehot" in dataset[0] + self.eval() + with torch.no_grad(): + B = self.hparams.Train.batch_size + N = len(dataset) + attrs_pos_z = [[0, 0] for _ in range(self.y_classes)] + attrs_neg_z = [[0, 0] for _ in range(self.y_classes)] + for i in tqdm(range(0, N, B)): + j = min([i + B, N]) + # generate z for data from i to j + xs = [dataset[k]["x"] for k in range(i, j)] + while len(xs) < B: + xs.append(dataset[0]["x"]) + xs = torch.stack(xs).cuda() + zs, _, _ = self(xs) + for k in range(i, j): + z = zs[k - i].detach().cpu().numpy() + # append to different attrs + y = dataset[k]["y_onehot"] + for ai in range(self.y_classes): + if y[ai] > 0: + attrs_pos_z[ai][0] += z + attrs_pos_z[ai][1] += 1 + else: + attrs_neg_z[ai][0] += z + attrs_neg_z[ai][1] += 1 + # break + deltaz = [] + for ai in range(self.y_classes): + if attrs_pos_z[ai][1] == 0: + attrs_pos_z[ai][1] = 1 + if attrs_neg_z[ai][1] == 0: + attrs_neg_z[ai][1] = 1 + z_pos = attrs_pos_z[ai][0] / float(attrs_pos_z[ai][1]) + z_neg = attrs_neg_z[ai][0] / float(attrs_neg_z[ai][1]) + deltaz.append(z_pos - z_neg) + self.train() + return deltaz diff --git a/codes/models/__init__.py b/codes/models/__init__.py new file mode 100644 index 0000000..9803258 --- /dev/null +++ b/codes/models/__init__.py @@ -0,0 +1,52 @@ +import importlib +import logging +import os + +try: + import local_config +except: + local_config = None + + +logger = logging.getLogger('base') + + +def find_model_using_name(model_name): + # Given the option --model [modelname], + # the file "models/modelname_model.py" + # will be imported. + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + + # In the file, the class called ModelNameModel() will + # be instantiated. It has to be a subclass of torch.nn.Module, + # and it is case-insensitive. + model = None + target_model_name = model_name.replace('_', '') + 'Model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower(): + model = cls + + if model is None: + print( + "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." 
% ( + model_filename, target_model_name)) + exit(0) + + return model + + +def create_model(opt, step=0, **opt_kwargs): + if local_config is not None: + opt['path']['pretrain_model_G'] = os.path.join(local_config.checkpoint_path, os.path.basename(opt['path']['results_root'] + '.pth')) + + for k, v in opt_kwargs.items(): + opt[k] = v + + model = opt['model'] + + M = find_model_using_name(model) + + m = M(opt, step) + logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) + return m diff --git a/codes/models/base_model.py b/codes/models/base_model.py new file mode 100644 index 0000000..e5af191 --- /dev/null +++ b/codes/models/base_model.py @@ -0,0 +1,166 @@ +import os +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel +import natsort +import glob + + +class BaseModel(): + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt.get('gpu_ids', None) is not None else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def get_current_losses(self): + pass + + def print_network(self): + pass + + def save(self, label): + pass + + def load(self): + pass + + def _set_lr(self, lr_groups_l): + ''' set learning rate for warmup, + lr_groups_l: list for lr_groups. each for a optimizer''' + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + # get the initial lr, which is set by the scheduler + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, cur_iter, warmup_iter=-1): + for scheduler in self.schedulers: + scheduler.step() + #### set up warm up learning rate + if cur_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + # return self.schedulers[0].get_lr()[0] + return self.optimizers[0].param_groups[0]['lr'] + + def get_network_description(self, network): + '''Get the string and total parameters of the network''' + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + s = str(network) + n = sum(map(lambda x: x.numel(), network.parameters())) + return s, n + + def save_network(self, network, network_label, iter_label): + paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['models'], "*_{}.pth".format(network_label))), + reverse=True) + paths = [p for p in paths if + "latest_" not in p and not any([str(i * 5000) in p.split("/")[-1].split("_") for i in range(101)])] + if len(paths) > 2: + for path in paths[2:]: + os.remove(path) + save_filename = '{}_{}.pth'.format(iter_label, network_label) + save_path = os.path.join(self.opt['path']['models'], save_filename) + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + state_dict = network.state_dict() + for key, param in state_dict.items(): + state_dict[key] = param.cpu() + 
torch.save(state_dict, save_path) + + def load_network(self, load_path, network, strict=True, submodule=None): + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + if not (submodule is None or submodule.lower() == 'none'.lower()): + network = network.__getattr__(submodule) + load_net = torch.load(load_path) + load_net_clean = OrderedDict() # remove unnecessary 'module.' + for k, v in load_net.items(): + + # if 'flowLRNet' in k or 'quantization' in k: + # continue + # k = k.replace('flowDownpsamplerNet', 'flow') + # + # # step2 + # k = k.replace('Split2d712', 'level0_condFlow') + # k = k.replace('layers.29', 'level1_condFlow') + # # k = k.replace('Split2d932', 'level2_condFlow') + + + if k.startswith('module.'): + load_net_clean[k[7:]] = v + else: + load_net_clean[k] = v + + network.load_state_dict(load_net_clean, strict=strict) + + # if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + # network = network.module + # state_dict = network.state_dict() + # for key, param in state_dict.items(): + # state_dict[key] = param.cpu() + # torch.save(state_dict, '../experiments/SR_DF2K_X4_HCFlow32.pth') + + def save_training_state(self, epoch, iter_step): + '''Saves training state during training, which will be used for resuming''' + state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + save_filename = '{}.state'.format(iter_step) + save_path = os.path.join(self.opt['path']['training_state'], save_filename) + + paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['training_state'], "*.state")), + reverse=True) + paths = [p for p in paths if "latest_" not in p] + if len(paths) > 2: + for path in paths[2:]: + os.remove(path) + + torch.save(state, save_path) + + def resume_training(self, resume_state): + '''Resume the optimizers and schedulers for training''' + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + # manually change lr milestones + # from collections import Counter + # s['milestones'] = Counter([100000, 150000, 200000, 250000, 300000, 350000, 400000]) # for multistage lr + # s['restarts'] = [120001, 240001, 360001] # for cosine_restart_lr + + self.schedulers[i].load_state_dict(s) + + + + diff --git a/codes/models/lr_scheduler.py b/codes/models/lr_scheduler.py new file mode 100644 index 0000000..2304f0c --- /dev/null +++ b/codes/models/lr_scheduler.py @@ -0,0 +1,144 @@ +import math +from collections import Counter +from collections import defaultdict +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepLR_Restart(_LRScheduler): + def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, + clear_state=False, last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.clear_state = clear_state + self.restarts = restarts if restarts else [0] + self.restarts = [v + 1 for v in self.restarts] + self.restart_weights = weights if weights else [1] + assert len(self.restarts) == len( + 
self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + if self.clear_state: + self.optimizer.state = defaultdict(dict) + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [ + group['lr'] * self.gamma**self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + +class CosineAnnealingLR_Restart(_LRScheduler): + def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): + self.T_period = T_period + self.T_max = self.T_period[0] # current T period + self.eta_min = eta_min + self.restarts = restarts if restarts else [0] + self.restarts = [v + 1 for v in self.restarts] + self.restart_weights = weights if weights else [1] + self.last_restart = 0 + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' + super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch == 0: + return self.base_lrs + elif self.last_epoch in self.restarts: + self.last_restart = self.last_epoch + self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / + (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * + (group['lr'] - self.eta_min) + self.eta_min + for group in self.optimizer.param_groups] + + +if __name__ == "__main__": + optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, + betas=(0.9, 0.99)) + ############################## + # MultiStepLR_Restart + ############################## + ## Original + lr_steps = [200000, 400000, 600000, 800000] + restarts = None + restart_weights = None + + ## two + lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] + restarts = [500000] + restart_weights = [1] + + ## four + lr_steps = [ + 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, + 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 + ] + restarts = [250000, 500000, 750000] + restart_weights = [1, 1, 1] + + scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, + clear_state=False) + + ############################## + # Cosine Annealing Restart + ############################## + ## two + T_period = [500000, 500000] + restarts = [500000] + restart_weights = [1] + + ## four + T_period = [250000, 250000, 250000, 250000] + restarts = [250000, 500000, 750000] + restart_weights = [1,1, 1] + + scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, + weights=restart_weights) + + ############################## + # Draw figure + ############################## + N_iter = 1000000 + lr_l = 
list(range(N_iter)) + for i in range(N_iter): + scheduler.step() + current_lr = optimizer.param_groups[0]['lr'] + lr_l[i] = current_lr + + import matplotlib as mpl + from matplotlib import pyplot as plt + import matplotlib.ticker as mtick + mpl.style.use('default') + import seaborn + seaborn.set(style='whitegrid') + seaborn.set_context('paper') + + plt.figure(1) + plt.subplot(111) + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) + plt.title('Title', fontsize=16, color='k') + plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') + legend = plt.legend(loc='upper right', shadow=False) + ax = plt.gca() + labels = ax.get_xticks().tolist() + for k, v in enumerate(labels): + labels[k] = str(int(v / 1000)) + 'K' + ax.set_xticklabels(labels) + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) + + ax.set_ylabel('Learning rate') + ax.set_xlabel('Iteration') + fig = plt.gcf() + plt.show() diff --git a/codes/models/modules/ActNorms.py b/codes/models/modules/ActNorms.py new file mode 100644 index 0000000..aa81f14 --- /dev/null +++ b/codes/models/modules/ActNorms.py @@ -0,0 +1,122 @@ +import torch +from torch import nn as nn + +from models.modules import thops + + +class _ActNorm(nn.Module): + """ + Activation Normalization + Initialize the bias and scale with a given minibatch, + so that the output per-channel have zero mean and unit variance for that. + + After initialization, `bias` and `logs` will be trained as parameters. + """ + + def __init__(self, num_features, scale=1.): + super().__init__() + # register mean and scale + size = [1, num_features, 1, 1] + self.register_parameter("bias", nn.Parameter(torch.zeros(*size))) + self.register_parameter("logs", nn.Parameter(torch.zeros(*size))) + self.num_features = num_features + self.scale = float(scale) + self.inited = False + + def _check_input_dim(self, input): + return NotImplemented + + def initialize_parameters(self, input): + self._check_input_dim(input) + if not self.training: + return + if (self.bias != 0).any(): + self.inited = True + return + assert input.device == self.bias.device, (input.device, self.bias.device) + with torch.no_grad(): + bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0 + vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) + logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) + self.bias.data.copy_(bias.data) + self.logs.data.copy_(logs.data) + self.inited = True + + def _center(self, input, reverse=False, offset=None): + bias = self.bias + + if offset is not None: + bias = bias + offset + + if not reverse: + return input + bias + else: + return input - bias + + def _scale(self, input, logdet=None, reverse=False, offset=None): + logs = self.logs + + if offset is not None: + logs = logs + offset + + if not reverse: + input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1 + # input = input * torch.exp(logs+logs_offset) + else: + input = input * torch.exp(-logs) + if logdet is not None: + """ + logs is log_std of `mean of channels` + so we need to multiply pixels + """ + dlogdet = thops.sum(logs) * thops.pixels(input) + if reverse: + dlogdet *= -1 + logdet = logdet + dlogdet + return input, logdet + + def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None): + if not self.inited: + self.initialize_parameters(input) + + if offset_mask is not None: + logs_offset *= offset_mask + bias_offset *= offset_mask + # no need to permute dims as old version + if 
not reverse: + # center and scale + input = self._center(input, reverse, bias_offset) + input, logdet = self._scale(input, logdet, reverse, logs_offset) + else: + # scale and center + input, logdet = self._scale(input, logdet, reverse, logs_offset) + input = self._center(input, reverse, bias_offset) + return input, logdet + + +class ActNorm2d(_ActNorm): + def __init__(self, num_features, scale=1.): + super().__init__(num_features, scale) + + def _check_input_dim(self, input): + assert len(input.size()) == 4 + assert input.size(1) == self.num_features, ( + "[ActNorm]: input should be in shape as `BCHW`," + " channels should be {} rather than {}".format( + self.num_features, input.size())) + + +class MaskedActNorm2d(ActNorm2d): + def __init__(self, num_features, scale=1.): + super().__init__(num_features, scale) + + def forward(self, input, mask, logdet=None, reverse=False): + + assert mask.dtype == torch.bool + output, logdet_out = super().forward(input, logdet, reverse) + + input[mask] = output[mask] + logdet[mask] = logdet_out[mask] + + return input, logdet + diff --git a/codes/models/modules/AffineCouplings.py b/codes/models/modules/AffineCouplings.py new file mode 100644 index 0000000..2c9b840 --- /dev/null +++ b/codes/models/modules/AffineCouplings.py @@ -0,0 +1,224 @@ +import torch +from torch import nn as nn +import torch.nn.functional as F + +from models.modules import thops +from models.modules.Basic import Conv2d, Conv2dZeros, DenseBlock, FCN, RDN +from utils.util import opt_get, register_hook, trunc_normal_ + + +class AffineCoupling(nn.Module): + def __init__(self, in_channels, cond_channels=None, opt=None): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = opt_get(opt, ['hidden_channels'], 64) + self.n_hidden_layers = 1 + self.kernel_hidden = 1 + self.cond_channels = cond_channels + f_in_channels = self.in_channels//2 if cond_channels is None else self.in_channels//2 + cond_channels + f_out_channels = (self.in_channels - self.in_channels//2) * 2 + nn_module = opt_get(opt, ['nn_module'], 'FCN') + if nn_module == 'DenseBlock': + self.f = DenseBlock(in_channels=f_in_channels, out_channels=f_out_channels, gc=self.hidden_channels) + elif nn_module == 'FCN': + self.f = FCN(in_channels=f_in_channels, out_channels=f_out_channels, hidden_channels=self.hidden_channels, + kernel_hidden=self.kernel_hidden, n_hidden_layers=self.n_hidden_layers) + + + def forward(self, z, u=None, y=None, logdet=None, reverse=False): + if not reverse: + return self.normal_flow(z, u, y, logdet) + else: + return self.reverse_flow(z, u, y, logdet) + + def normal_flow(self, z, u=None, y=None, logdet=None): + z1, z2 = thops.split_feature(z, "split") + + h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) + shift, scale = thops.split_feature(h, "cross") + # adding 1e-4 is crucial for torch.slogdet(), as used in Glow (leads to black rect in experiments). + # see https://github.com/didriknielsen/survae_flows/issues/5 for discussion. + # or use `torch.exp(2. * torch.tanh(s / 2.)) as in SurVAE (more unstable in practice). + + # version 1, srflow (use FCN) + # scale = torch.sigmoid(scale + 2.) + 1e-4 + # z2 = (z2 + shift) * scale + # logdet += thops.sum(torch.log(scale), dim=[1, 2, 3]) + + # version2, survae + # logscale = 2. * torch.tanh(scale / 2.) + # z2 = (z2+shift) * torch.exp(logscale) # as in glow, it's shift+scale! + # logdet += thops.sum(logscale, dim=[1, 2, 3]) + + # version3, FrEIA, now have problem with FCN, but densenet is ok. 
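A minimal standalone sketch of the data-dependent initialisation used by _ActNorm above (plain PyTorch, not repo code): bias is set to the negative per-channel mean and logs to log(scale/std), so the first minibatch comes out roughly zero-mean and unit-variance per channel.

import torch

x = torch.randn(8, 4, 16, 16) * 3.0 + 5.0               # fake first minibatch, 4 channels
bias = -x.mean(dim=[0, 2, 3], keepdim=True)              # bias = -mean, as in initialize_parameters()
var = ((x + bias) ** 2).mean(dim=[0, 2, 3], keepdim=True)
logs = torch.log(1.0 / (var.sqrt() + 1e-6))              # scale = 1.0
y = (x + bias) * torch.exp(logs)                         # _center() then _scale(), forward direction
print(y.mean(dim=[0, 2, 3]), y.std(dim=[0, 2, 3]))       # per-channel mean ~0, std ~1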
(use FCN2/Denseblock) + # logscale = 0.5 * 0.636 * torch.atan(scale / 0.5) # clamp it to be between [-0.5,0.5] + logscale = 0.318 * torch.atan(2 * scale) + # logscale = 1.0 * 0.636 * torch.atan(scale / 1.0) + z2 = (z2 + shift) * torch.exp(logscale) + if logdet is not None: + logdet += thops.sum(logscale, dim=[1, 2, 3]) + + z = thops.cat_feature(z1, z2) + + return z, logdet + + def reverse_flow(self, z, u=None, y=None, logdet=None): + z1, z2 = thops.split_feature(z, "split") + + h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) + shift, scale = thops.split_feature(h, "cross") + + # version1, srflow + # scale = torch.sigmoid(scale + 2.) + 1e-4 + # z2 = (z2 / scale) -shift + + # version2, survae + # logscale = 2. * torch.tanh(scale / 2.) + # z2 = z2 * torch.exp(-logscale) - shift + + # version3, FrEIA + # logscale = 0.5 * 0.636 * torch.atan(scale / 0.5) + logscale = 0.318 * torch.atan(2 * scale) + # logscale = 1 * 0.636 * torch.atan(scale / 1.0) + z2 = z2 * torch.exp(-logscale) - shift + + z = thops.cat_feature(z1, z2) + + return z, logdet + + +'''3 channel conditional on the rest channels, or vice versa. only shift LR. + used in image rescaling to divide the low-frequencies and the high-frequencies apart from early flow layers.''' +class AffineCoupling3shift(nn.Module): + def __init__(self, in_channels, cond_channels=None, LRvsothers=True, opt=None): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = opt_get(opt, ['hidden_channels'], 64) + self.n_hidden_layers = 1 + self.kernel_hidden = 1 + self.cond_channels = cond_channels + self.LRvsothers = LRvsothers + if LRvsothers: + f_in_channels = 3 if cond_channels is None else 3 + cond_channels + f_out_channels = (self.in_channels - 3) * 2 + else: + f_in_channels = self.in_channels - 3 if cond_channels is None else self.in_channels - 3 + cond_channels + f_out_channels = 3 + nn_module = opt_get(opt, ['nn_module'], 'FCN') + + if nn_module == 'DenseBlock': + self.f = DenseBlock(in_channels=f_in_channels, out_channels=f_out_channels, gc=self.hidden_channels) + elif nn_module == 'FCN': + self.f = FCN(in_channels=f_in_channels, out_channels=f_out_channels, hidden_channels=self.hidden_channels, + kernel_hidden=self.kernel_hidden, n_hidden_layers=self.n_hidden_layers) + + + def forward(self, z, u=None, y=None, logdet=None, reverse=False): + if not reverse: + return self.normal_flow(z, u, y, logdet) + else: + return self.reverse_flow(z, u, y, logdet) + + def normal_flow(self, z, u=None, y=None, logdet=None): + if self.LRvsothers: + z1, z2 = z[:, :3, ...], z[:, 3:, ...] + h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) + shift, scale = thops.split_feature(h, "cross") + logscale = 0.318 * torch.atan(2 * scale) + z2 = (z2 + shift) * torch.exp(logscale) + if logdet is not None: + logdet += thops.sum(logscale, dim=[1, 2, 3]) + else: + z2, z1 = z[:, :3, ...], z[:, 3:, ...] + shift = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) + z2 = z2 + shift + + if self.LRvsothers: + z = thops.cat_feature(z1, z2) + else: + z = thops.cat_feature(z2, z1) + + return z, logdet + + def reverse_flow(self, z, u=None, y=None, logdet=None): + if self.LRvsothers: + z1, z2 = z[:, :3, ...], z[:, 3:, ...] 
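The soft clamp 0.318 * atan(2 * scale) used in the couplings above keeps the predicted log-scale inside roughly (-0.5, 0.5), since 0.318 is about 1/pi and atan(.) is bounded by +-pi/2; that bounds the per-element multiplier to about (0.61, 1.65), which keeps torch.exp() and the log-determinant numerically safe. A quick standalone check (illustrative only):

import torch

scale = torch.tensor([-1e4, -1.0, 0.0, 1.0, 1e4])
logscale = 0.318 * torch.atan(2 * scale)
print(logscale)                  # stays within about (-0.5, 0.5) even for huge inputs
print(torch.exp(logscale))       # multiplier bounded in roughly (0.61, 1.65)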
+ h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) + shift, scale = thops.split_feature(h, "cross") + logscale = 0.318 * torch.atan(2 * scale) + z2 = z2 * torch.exp(-logscale) - shift + else: + z2, z1 = z[:, :3, ...], z[:, 3:, ...] + shift = self.f(z1) + z2 = z2 - shift + + if self.LRvsothers: + z = thops.cat_feature(z1, z2) + else: + z = thops.cat_feature(z2, z1) + + return z, logdet + + +''' srflow's affine injector + original affine coupling, not used in this project''' +class AffineCouplingInjector(nn.Module): + def __init__(self, in_channels, cond_channels=None, opt=None): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = opt_get(opt, ['hidden_channels'], 64) + self.n_hidden_layers = 1 + self.kernel_hidden = 1 + self.cond_channels = cond_channels + f_in_channels = self.in_channels//2 if cond_channels is None else self.in_channels//2 + cond_channels + f_out_channels = (self.in_channels - self.in_channels//2) * 2 + nn_module = opt_get(opt, ['nn_module'], 'FCN') + if nn_module == 'DenseBlock': + self.f = DenseBlock(in_channels=f_in_channels, out_channels=f_out_channels, gc=self.hidden_channels) + self.f_injector = DenseBlock(in_channels=cond_channels, out_channels=self.in_channels*2, gc=self.hidden_channels) + elif nn_module == 'FCN': + self.f = FCN(in_channels=f_in_channels, out_channels=f_out_channels, hidden_channels=self.hidden_channels, + kernel_hidden=self.kernel_hidden, n_hidden_layers=self.n_hidden_layers) + self.f_injector = FCN(in_channels=cond_channels, out_channels=self.in_channels*2, hidden_channels=self.hidden_channels, + kernel_hidden=self.kernel_hidden, n_hidden_layers=self.n_hidden_layers) + + def forward(self, z, u=None, y=None, logdet=None, reverse=False): + if not reverse: + return self.normal_flow(z, u, y, logdet) + else: + return self.reverse_flow(z, u, y, logdet) + + def normal_flow(self, z, u=None, y=None, logdet=None): + # overall-conditional + h = self.f_injector(u) + shift, scale = thops.split_feature(h, "cross") + logscale = 0.318 * torch.atan(2 * scale) # clamp it to be between [-5,5] + z = (z + shift) * torch.exp(logscale) + logdet += thops.sum(logscale, dim=[1, 2, 3]) + + # self-conditional + z1, z2 = thops.split_feature(z, "split") + h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) + shift, scale = thops.split_feature(h, "cross") + logscale = 0.318 * torch.atan(2 * scale) # clamp it to be between [-5,5] + z2 = (z2 + shift) * torch.exp(logscale) + logdet += thops.sum(logscale, dim=[1, 2, 3]) + z = thops.cat_feature(z1, z2) + + return z, logdet + + def reverse_flow(self, z, u=None, y=None, logdet=None): + # self-conditional + z1, z2 = thops.split_feature(z, "split") + h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) + shift, scale = thops.split_feature(h, "cross") + logscale = 0.318 * torch.atan(2 * scale) + z2 = z2 * torch.exp(-logscale) - shift + z = thops.cat_feature(z1, z2) + + # overall-conditional + h = self.f_injector(u) + shift, scale = thops.split_feature(h, "cross") + logscale = 0.318 * torch.atan(2 * scale) + z = z * torch.exp(-logscale) - shift + + return z, logdet diff --git a/codes/models/modules/Basic.py b/codes/models/modules/Basic.py new file mode 100644 index 0000000..aa5897b --- /dev/null +++ b/codes/models/modules/Basic.py @@ -0,0 +1,500 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import functools + +from models.modules.ActNorms import ActNorm2d +from . 
import thops + +from utils.util import opt_get +import models.modules.module_util as mutil + + +class Conv2d(nn.Conv2d): + pad_dict = { + "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)], + "valid": lambda kernel, stride: [0 for _ in kernel] + } + + @staticmethod + def get_padding(padding, kernel_size, stride): + # make paddding + if isinstance(padding, str): + if isinstance(kernel_size, int): + kernel_size = [kernel_size, kernel_size] + if isinstance(stride, int): + stride = [stride, stride] + padding = padding.lower() + try: + padding = Conv2d.pad_dict[padding](kernel_size, stride) + except KeyError: + raise ValueError("{} is not supported".format(padding)) + return padding + + def __init__(self, in_channels, out_channels, + kernel_size=[3, 3], stride=[1, 1], padding="same", groups=1, + do_actnorm=True, weight_std=0.05): + padding = Conv2d.get_padding(padding, kernel_size, stride) + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, groups=groups, bias=(not do_actnorm)) + # init weight with std + self.weight.data.normal_(mean=0.0, std=weight_std) + if not do_actnorm: + self.bias.data.zero_() + else: + self.actnorm = ActNorm2d(out_channels) + self.do_actnorm = do_actnorm + + def forward(self, input): + x = super().forward(input) + if self.do_actnorm: + x, _ = self.actnorm(x) + return x + +# Zero initialization. We initialize the last convolution of each NN() with zeros, such that each affine +# coupling layer initially performs an identity function; we found that this helps training very deep networks. +class Conv2dZeros(nn.Conv2d): + def __init__(self, in_channels, out_channels, + kernel_size=[3, 3], stride=[1, 1], + padding="same", logscale_factor=3): + padding = Conv2d.get_padding(padding, kernel_size, stride) + super().__init__(in_channels, out_channels, kernel_size, stride, padding) + # logscale_factor + self.logscale_factor = logscale_factor + self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1))) + # init + self.weight.data.zero_() + self.bias.data.zero_() + + def forward(self, input): + output = super().forward(input) + return output * torch.exp(self.logs * self.logscale_factor) + + +class GaussianDiag: + Log2PI = float(np.log(2 * np.pi)) + + @staticmethod + def likelihood(mean, logs, x): # logs: log(sigma) + """ + lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) } + k = 1 (Independent) + Var = logs ** 2 + """ + if mean is None and logs is None: + return -0.5 * (x ** 2 + GaussianDiag.Log2PI) + else: + return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.) 
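A small illustration (not repo code) of the "same"-padding rule in Conv2d.pad_dict above: padding = ((k - 1) * s + 1) // 2 per dimension, which gives 1 for a 3x3 kernel at stride 1 and 0 for a 1x1 kernel, so spatial size is preserved.

same = lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)]
print(same([3, 3], [1, 1]))   # [1, 1]
print(same([1, 1], [1, 1]))   # [0, 0]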
+ GaussianDiag.Log2PI) + + @staticmethod + def logp(mean, logs, x): + likelihood = GaussianDiag.likelihood(mean, logs, x) + return thops.sum(likelihood, dim=[1, 2, 3]) + + @staticmethod + def sample(mean, logs, eps_std=None): + # eps_std = eps_std or 1 # may cause problem when eps_std is 0 + eps = torch.normal(mean=torch.zeros_like(mean), + std=torch.ones_like(logs) * eps_std) + return mean + torch.exp(logs) * eps + + @staticmethod + def sample_eps(shape, eps_std, seed=None): + if seed is not None: + torch.manual_seed(seed) + eps = torch.normal(mean=torch.zeros(shape), + std=torch.ones(shape) * eps_std) + return eps + + +class LaplaceDiag: + Log2= float(np.log(2)) + + @staticmethod + def likelihood(mean, logs, x): # logs: log(sigma) + if mean is None and logs is None: + return - (torch.abs(x) + LaplaceDiag.Log2) + else: + return - (logs + (torch.abs(x - mean)) / torch.exp(logs) + LaplaceDiag.Log2) + + @staticmethod + def logp(mean, logs, x): + likelihood = LaplaceDiag.likelihood(mean, logs, x) + return thops.sum(likelihood, dim=[1, 2, 3]) + + +def squeeze2d(input, factor=2): + assert factor >= 1 and isinstance(factor, int) + if factor == 1: + return input + size = input.size() + B = size[0] + C = size[1] + H = size[2] + W = size[3] + assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor)) + x = input.view(B, C, H // factor, factor, W // factor, factor) + x = x.permute(0, 1, 3, 5, 2, 4).contiguous() + x = x.view(B, C * factor * factor, H // factor, W // factor) + return x + + +def unsqueeze2d(input, factor=2): + assert factor >= 1 and isinstance(factor, int) + factor2 = factor ** 2 + if factor == 1: + return input + size = input.size() + B = size[0] + C = size[1] + H = size[2] + W = size[3] + assert C % (factor2) == 0, "{}".format(C) + x = input.view(B, C // factor2, factor, factor, H, W) + x = x.permute(0, 1, 4, 2, 5, 3).contiguous() + x = x.view(B, C // (factor2), H * factor, W * factor) + return x + + +class SqueezeLayer(nn.Module): + def __init__(self, factor): + super().__init__() + self.factor = factor + + def forward(self, input, logdet=None, reverse=False): + if not reverse: + output = squeeze2d(input, self.factor) + return output, logdet + else: + output = unsqueeze2d(input, self.factor) + return output, logdet + +class UnSqueezeLayer(nn.Module): + def __init__(self, factor): + super().__init__() + self.factor = factor + + def forward(self, input, logdet=None, reverse=False): + if not reverse: + output = unsqueeze2d(input, self.factor) + return output, logdet + else: + output = squeeze2d(input, self.factor) + return output, logdet + +class Quant(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + input = torch.clamp(input, 0, 1) + output = (input * 255.).round() / 255. + return output + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +class Quantization(nn.Module): + def __init__(self): + super(Quantization, self).__init__() + + def forward(self, input): + return Quant.apply(input) + +class Sigmoid(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, logdet=None, reverse=False): + if not reverse: + output = torch.sigmoid(input) + logdet += -thops.sum(F.softplus(input)+F.softplus(-input), dim=[1, 2, 3]) + return output, logdet + else: + output = -torch.log(torch.reciprocal(input) - 1.) 
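GaussianDiag.likelihood above is the element-wise log-density of a diagonal Gaussian with logs = log(sigma); a standalone cross-check against torch.distributions (illustrative, not part of the repo):

import math
import torch
from torch.distributions import Normal

mean = torch.randn(2, 3, 4, 4)
logs = 0.1 * torch.randn(2, 3, 4, 4)
x = torch.randn(2, 3, 4, 4)
manual = -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.) + math.log(2 * math.pi))
reference = Normal(mean, torch.exp(logs)).log_prob(x)
print(torch.allclose(manual, reference, atol=1e-5))   # True; GaussianDiag.logp() then sums over C, H, W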
+ logdet += -thops.sum(torch.log(input) + torch.log(1.-input), dim=[1, 2, 3]) + return output, logdet + +# used in SRFlow +class Split2d_conditional(nn.Module): + def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None): + super().__init__() + + self.num_channels_consume = int(round(num_channels * consume_ratio)) + self.num_channels_pass = num_channels - self.num_channels_consume + + self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels, + out_channels=self.num_channels_consume * 2) + self.logs_eps = logs_eps + self.position = position + self.opt = opt + + def split2d_prior(self, z, ft): + if ft is not None: + z = torch.cat([z, ft], dim=1) + h = self.conv(z) + return thops.split_feature(h, "cross") + + def exp_eps(self, logs): + return torch.exp(logs) + self.logs_eps + + def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None): + if not reverse: + # self.input = input + z1, z2 = self.split_ratio(input) + mean, logs = self.split2d_prior(z1, ft) + + eps = (z2 - mean) / self.exp_eps(logs) + + logdet = logdet + self.get_logdet(logs, mean, z2) + + # print(logs.shape, mean.shape, z2.shape) + # self.eps = eps + # print('split, enc eps:', eps) + return z1, logdet, eps + else: + z1 = input + mean, logs = self.split2d_prior(z1, ft) + + if eps is None: + #print("WARNING: eps is None, generating eps untested functionality!") + eps = GaussianDiag.sample_eps(mean.shape, eps_std) + + eps = eps.to(mean.device) + z2 = mean + self.exp_eps(logs) * eps + + z = thops.cat_feature(z1, z2) + logdet = logdet - self.get_logdet(logs, mean, z2) + + return z, logdet + # return z, logdet, eps + + def get_logdet(self, logs, mean, z2): + logdet_diff = GaussianDiag.logp(mean, logs, z2) + # print("Split2D: logdet diff", logdet_diff.item()) + return logdet_diff + + def split_ratio(self, input): + z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...] + return z1, z2 + + +''' Not used anymore ''' +class Split2d(nn.Module): + def __init__(self, num_channels): + super().__init__() + self.conv = Conv2dZeros(num_channels // 2, num_channels) + + def split2d_prior(self, z): + h = self.conv(z) + return thops.split_feature(h, "cross") + + def forward(self, input, logdet=0., reverse=False, eps_std=None): + if not reverse: + z1, z2 = thops.split_feature(input, "split") + mean, logs = self.split2d_prior(z1) + logdet = GaussianDiag.logp(mean, logs, z2) + logdet + return z1, logdet + else: + z1 = input + mean, logs = self.split2d_prior(z1) + z2 = GaussianDiag.sample(mean, logs, eps_std) + z = thops.cat_feature(z1, z2) + return z, logdet + +class Split2d_LR(nn.Module): + def __init__(self, num_channels, num_channels_split): + super().__init__() + self.num_channels_split = num_channels_split + self.conv = Conv2dZeros(num_channels_split, (num_channels-num_channels_split)*2) + + def split2d_prior(self, z): + h = self.conv(z) + return thops.split_feature(h, "cross") + + def forward(self, input, eps_std=None, logdet=0., reverse=False): + if not reverse: + z1, z2 = input[:, :self.num_channels_split, ...], input[:, self.num_channels_split:, ...] 
+ mean, logs = self.split2d_prior(z1) + logdet += GaussianDiag.logp(mean, logs, z2) + return z1, logdet + else: + z1 = input + mean, logs = self.split2d_prior(z1) + z2 = GaussianDiag.sample(mean, logs, eps_std) + z = torch.cat((z1, z2), dim=1) + return z, logdet + +# DenseBlock for affine coupling (flow) +class DenseBlock(nn.Module): + def __init__(self, in_channels, out_channels, gc=32, bias=True, init='xavier', for_flow=True): + super(DenseBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(in_channels + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(in_channels + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(in_channels + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(in_channels + 4 * gc, out_channels, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # init as 'xavier', following the practice in https://github.com/VLL-HD/FrEIA/blob/c5fe1af0de8ce9122b5b61924ad75a19b9dc2473/README.rst#useful-tips--engineering-heuristics + if init == 'xavier': + mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + else: + mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + # initialiize input to all zeros to have zero mean and unit variance + if for_flow: + mutil.initialize_weights(self.conv5, 0) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + + return x5 + + +# ResidualDenseBlock for multi-layer feature extraction +class ResidualDenseBlock(nn.Module): + def __init__(self, nf=64, gc=32, bias=True, init='xavier'): + super(ResidualDenseBlock, self).__init__() + # gc: growth channel, i.e. 
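Because the coupling networks end in a zero-initialised convolution (Conv2dZeros, and for_flow=True zeroing conv5 in DenseBlock above), every affine coupling starts out as an identity map with zero log-determinant; a sketch of why (plain tensors, illustrative only):

import torch

z2 = torch.randn(1, 6, 8, 8)
shift = torch.zeros_like(z2)                 # zero-initialised NN output at step 0
scale = torch.zeros_like(z2)
logscale = 0.318 * torch.atan(2 * scale)     # = 0 everywhere
out = (z2 + shift) * torch.exp(logscale)
print(torch.equal(out, z2))                  # True: identity at initialisation, logdet contribution 0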
intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # init as 'xavier', following the practice in https://github.com/VLL-HD/FrEIA/blob/c5fe1af0de8ce9122b5b61924ad75a19b9dc2473/README.rst#useful-tips--engineering-heuristics + if init == 'xavier': + mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + else: + mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x # residual scaling are helpful + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock(nf, gc) + self.RDB2 = ResidualDenseBlock(nf, gc) + self.RDB3 = ResidualDenseBlock(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x # residual scaling are helpful + +class RDN(nn.Module): + '''composed of rrdb blocks''' + + def __init__(self, in_channels, out_channels, nb=3, nf=64, gc=32, init='xavier', for_flow=True): + super(RDN, self).__init__() + + RRDB_f = functools.partial(RRDB, nf=nf, gc=gc) + self.conv_first = nn.Conv2d(in_channels, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = mutil.make_layer(RRDB_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_channels, 3, 1, 1, bias=True) + + if init == 'xavier': + mutil.initialize_weights_xavier([self.conv_first, self.trunk_conv, self.conv_last], 0.1) + else: + mutil.initialize_weights([self.conv_first, self.trunk_conv, self.conv_last], 0.1) + + if for_flow: + mutil.initialize_weights(self.conv_last, 0) + + def forward(self, x): + x = self.conv_first(x) + x = self.trunk_conv(self.RRDB_trunk(x)) + x + return self.conv_last(x) + + +class FCN(nn.Module): + def __init__(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1, init='xavier', for_flow=True): + super(FCN, self).__init__() + self.conv1 = Conv2d(in_channels, hidden_channels, kernel_size=[3, 3], stride=[1, 1]) + self.conv2 = Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden]) + self.conv3 = Conv2dZeros(hidden_channels, out_channels, kernel_size=[3, 3], stride=[1, 1]) + self.relu = nn.ReLU(inplace=False) + + if init == 'xavier': + mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3], 0.1) + else: + mutil.initialize_weights([self.conv1, self.conv2, self.conv3], 0.1) + + if for_flow: + mutil.initialize_weights(self.conv3, 0) + + def forward(self, x): + x = self.relu(self.conv1(x)) + x = self.relu(self.conv2(x)) + x = self.conv3(x) + + return x + + +class HaarDownsampling(nn.Module): + def __init__(self, channel_in): + super(HaarDownsampling, self).__init__() + self.channel_in = channel_in + + self.haar_weights = torch.ones(4, 1, 2, 2) + + self.haar_weights[1, 0, 0, 1] = -1 + self.haar_weights[1, 0, 1, 1] = -1 + + 
self.haar_weights[2, 0, 1, 0] = -1 + self.haar_weights[2, 0, 1, 1] = -1 + + self.haar_weights[3, 0, 1, 0] = -1 + self.haar_weights[3, 0, 0, 1] = -1 + + self.haar_weights = torch.cat([self.haar_weights] * self.channel_in, 0) + self.haar_weights = nn.Parameter(self.haar_weights) + self.haar_weights.requires_grad = False + + def forward(self, x, logdet=None, reverse=False): + if not reverse: + self.elements = x.shape[1] * x.shape[2] * x.shape[3] + self.last_jac = self.elements / 4 * np.log(1/16.) + + out = F.conv2d(x, self.haar_weights, bias=None, stride=2, groups=self.channel_in) / 4.0 + out = out.reshape([x.shape[0], self.channel_in, 4, x.shape[2] // 2, x.shape[3] // 2]) + out = torch.transpose(out, 1, 2) + out = out.reshape([x.shape[0], self.channel_in * 4, x.shape[2] // 2, x.shape[3] // 2]) + return out, logdet + else: + self.elements = x.shape[1] * x.shape[2] * x.shape[3] + self.last_jac = self.elements / 4 * np.log(16.) + + out = x.reshape([x.shape[0], 4, self.channel_in, x.shape[2], x.shape[3]]) + out = torch.transpose(out, 1, 2) + out = out.reshape([x.shape[0], self.channel_in * 4, x.shape[2], x.shape[3]]) + return F.conv_transpose2d(out, self.haar_weights, bias=None, stride=2, groups = self.channel_in), logdet + +class Split(nn.Module): + def __init__(self, num_channels_split, level): + super().__init__() + self.num_channels_split = num_channels_split + self.level = level + + def forward(self, z, z2=None, reverse=False): + if not reverse: + return z[:, :self.num_channels_split, ...], z[:, self.num_channels_split:, ...] + else: + return torch.cat((z, z2), dim=1) + diff --git a/codes/models/modules/ConditionalFlow.py b/codes/models/modules/ConditionalFlow.py new file mode 100644 index 0000000..0581484 --- /dev/null +++ b/codes/models/modules/ConditionalFlow.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from . import thops +from utils.util import opt_get +from models.modules.Basic import Conv2d, Conv2dZeros, GaussianDiag, DenseBlock, RRDB, FCN +from models.modules.FlowStep import FlowStep + +import functools +import models.modules.module_util as mutil + + +class ConditionalFlow(nn.Module): + def __init__(self, num_channels, num_channels_split, n_flow_step=0, opt=None, num_levels_condition=0, SR=True): + super().__init__() + self.SR = SR + + # number of levels of RRDB features. 
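HaarDownsampling above is an exactly invertible squeeze: it stacks the four Haar bands per input channel, mapping (B, C, H, W) to (B, 4C, H/2, W/2), and the reverse pass reconstructs the input. A usage sketch, assuming the codes/ directory from this diff is importable:

import torch
from models.modules.Basic import HaarDownsampling   # module path as introduced in this diff

haar = HaarDownsampling(channel_in=3)
x = torch.rand(2, 3, 16, 16)
y, _ = haar(x, reverse=False)                        # (2, 12, 8, 8)
x_rec, _ = haar(y, reverse=True)
print(y.shape, torch.allclose(x, x_rec, atol=1e-5))  # exact round trip up to float error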
One level of conditional feature is enough for image rescaling + num_features_condition = 2 if self.SR else 1 + + # feature extraction + RRDB_nb = opt_get(opt, ['RRDB_nb'], [5, 5]) + RRDB_nf = opt_get(opt, ['RRDB_nf'], 64) + RRDB_gc = opt_get(opt, ['RRDB_gc'], 32) + RRDB_f = functools.partial(RRDB, nf=RRDB_nf, gc=RRDB_gc) + self.conv_first = nn.Conv2d(num_channels_split + RRDB_nf*num_features_condition*num_levels_condition, RRDB_nf, 3, 1, 1, bias=True) + self.RRDB_trunk0 = mutil.make_layer(RRDB_f, RRDB_nb[0]) + self.RRDB_trunk1 = mutil.make_layer(RRDB_f, RRDB_nb[1]) + self.trunk_conv1 = nn.Conv2d(RRDB_nf, RRDB_nf, 3, 1, 1, bias=True) + + # conditional flow + self.additional_flow_steps = nn.ModuleList() + for k in range(n_flow_step): + self.additional_flow_steps.append(FlowStep(in_channels=num_channels-num_channels_split, + cond_channels=RRDB_nf*num_features_condition, + flow_permutation=opt['flow_permutation'], + flow_coupling=opt['flow_coupling'], opt=opt)) + + self.f = Conv2dZeros(RRDB_nf*num_features_condition, (num_channels-num_channels_split)*2) + + + def forward(self, z, u, eps_std=None, logdet=0., reverse=False, training=True): + # for image SR + if self.SR: + if not reverse: + conditional_feature = self.get_conditional_feature_SR(u) + + for layer in self.additional_flow_steps: + z, logdet = layer(z, u=conditional_feature, logdet=logdet, reverse=False) + + h = self.f(conditional_feature) + mean, logs = thops.split_feature(h, "cross") + logdet += GaussianDiag.logp(mean, logs, z) + + return logdet, conditional_feature + + else: + conditional_feature = self.get_conditional_feature_SR(u) + + h = self.f(conditional_feature) + mean, logs = thops.split_feature(h, "cross") + z = GaussianDiag.sample(mean, logs, eps_std) + + for layer in reversed(self.additional_flow_steps): + z, _ = layer(z, u=conditional_feature, reverse=True) + + return z, logdet, conditional_feature + else: + # for image rescaling + if not reverse: + conditional_feature = self.get_conditional_feature_Rescaling(u) + + for layer in self.additional_flow_steps: + z, logdet = layer(z, u=conditional_feature, logdet=logdet, reverse=False) + + h = self.f(conditional_feature) + mean, scale = thops.split_feature(h, "cross") + logscale = 0.318 * torch.atan(2 * scale) + z = (z - mean) * torch.exp(-logscale) + + return z, conditional_feature + + else: + conditional_feature = self.get_conditional_feature_Rescaling(u) + + h = self.f(conditional_feature) + mean, scale = thops.split_feature(h, "cross") + logscale = 0.318 * torch.atan(2 * scale) + z = GaussianDiag.sample(mean, logscale, eps_std) + + for layer in reversed(self.additional_flow_steps): + z, _ = layer(z, u=conditional_feature, reverse=True) + + return z, conditional_feature + + + def get_conditional_feature_SR(self, u): + u_feature_first = self.conv_first(u) + u_feature1 = self.RRDB_trunk0(u_feature_first) + u_feature2 = self.trunk_conv1(self.RRDB_trunk1(u_feature1)) + u_feature_first + + return torch.cat([u_feature1, u_feature2], 1) + + def get_conditional_feature_Rescaling(self, u): + u_feature_first = self.conv_first(u) + u_feature = self.trunk_conv1(self.RRDB_trunk1(self.RRDB_trunk0(u_feature_first))) + u_feature_first + + return u_feature + + diff --git a/codes/models/modules/FlowNet_Rescaling_x4.py b/codes/models/modules/FlowNet_Rescaling_x4.py new file mode 100644 index 0000000..de8f4b3 --- /dev/null +++ b/codes/models/modules/FlowNet_Rescaling_x4.py @@ -0,0 +1,129 @@ +import numpy as np +import torch +from torch import nn as nn +import torch.nn.functional as F + +from 
utils.util import opt_get +from models.modules import Basic +from models.modules.FlowStep import FlowStep +from models.modules.ConditionalFlow import ConditionalFlow + +class FlowNet(nn.Module): + def __init__(self, image_shape, opt=None): + assert image_shape[2] == 1 or image_shape[2] == 3 + super().__init__() + H, W, self.C = image_shape + self.opt = opt + self.L = opt_get(opt, ['network_G', 'flowDownsampler', 'L']) + self.K = opt_get(opt, ['network_G', 'flowDownsampler', 'K']) + if isinstance(self.K, int): self.K = [self.K] * (self.L + 1) + + squeeze = opt_get(self.opt, ['network_G', 'flowDownsampler', 'squeeze'], 'checkerboard') + n_additionalFlowNoAffine = opt_get(self.opt, ['network_G', 'flowDownsampler', 'additionalFlowNoAffine'], 0) + flow_permutation = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_permutation'], 'invconv') + flow_coupling = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_coupling'], 'Affine') + cond_channels = opt_get(self.opt, ['network_G', 'flowDownsampler', 'cond_channels'], None) + enable_splitOff = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'enable'], False) + after_splitOff_flowStep = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'after_flowstep'], 0) + if isinstance(after_splitOff_flowStep, int): after_splitOff_flowStep = [after_splitOff_flowStep] * (self.L + 1) + + # construct flow + self.layers = nn.ModuleList() + self.output_shapes = [] + + for level in range(self.L): + # 1. Squeeze + if squeeze == 'checkerboard': + self.layers.append(Basic.SqueezeLayer(factor=2)) # may need a better way for squeezing + elif squeeze == 'haar': + self.layers.append(Basic.HaarDownsampling(channel_in=self.C)) + + self.C, H, W = self.C * 4, H // 2, W // 2 + self.output_shapes.append([-1, self.C, H, W]) + + # 2. main FlowSteps (uncodnitional flow) + for k in range(self.K[level]-after_splitOff_flowStep[level]): + self.layers.append(FlowStep(in_channels=self.C, cond_channels=cond_channels, + flow_permutation=flow_permutation, + flow_coupling=flow_coupling, + LRvsothers=True if k%2==0 else False, + opt=opt['network_G']['flowDownsampler'])) + self.output_shapes.append([-1, self.C, H, W]) + + # 3. 
additional FlowSteps (split + conditional flow) + if enable_splitOff: + if level == 0: + self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) + self.level0_condFlow = ConditionalFlow(num_channels=self.C, + num_channels_split=self.C // 2 if level < self.L-1 else 3, + n_flow_step=after_splitOff_flowStep[level], + opt=opt['network_G']['flowDownsampler']['splitOff'], + num_levels_condition=1, SR=False) + elif level == 1: + self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) + self.level1_condFlow = (ConditionalFlow(num_channels=self.C, + num_channels_split=self.C // 2 if level < self.L-1 else 3, + n_flow_step=after_splitOff_flowStep[level], + opt=opt['network_G']['flowDownsampler']['splitOff'], + num_levels_condition=0, SR=False)) + + self.C = self.C // 2 if level < self.L-1 else 3 + self.output_shapes.append([-1, self.C, H, W]) + + + self.H = H + self.W = W + self.scaleH = image_shape[0] / H + self.scaleW = image_shape[1] / W + print('shapes:', self.output_shapes) + + def forward(self, hr=None, z=None, u=None, eps_std=None, logdet=None, reverse=False, training=True): + if not reverse: + return self.normal_flow(hr, u=u, logdet=logdet, training=training) + else: + return self.reverse_flow(z, u=u, eps_std=eps_std, training=training) + + + ''' + hr->y1+z1->y2+z2 + ''' + def normal_flow(self, z, u=None, logdet=None, training=True): + for layer, shape in zip(self.layers, self.output_shapes): + if isinstance(layer, FlowStep): + z, _ = layer(z, u, logdet=logdet, reverse=False) + elif isinstance(layer, Basic.SqueezeLayer) or isinstance(layer, Basic.HaarDownsampling): + z, _ = layer(z, logdet=logdet, reverse=False) + elif isinstance(layer, Basic.Split): + if layer.level == 0: + z, a1 = layer(z, reverse=False) + y1 = z.clone() + elif layer.level == 1: + z, a2 = layer(z, reverse=False) + fake_z2, conditional_feature2 = self.level1_condFlow(a2, z, logdet=logdet, reverse=False, training=training) + + conditional_feature1 = torch.cat([y1, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest')],1) + fake_z1, _ = self.level0_condFlow(a1, conditional_feature1, logdet=logdet, reverse=False, training=training) + + return z, fake_z1, fake_z2 + + ''' + y2+z2->y1+z1->hr + ''' + def reverse_flow(self, z, u=None, eps_std=None, training=True): + for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): + if isinstance(layer, FlowStep): + z, _ = layer(z, u, reverse=True) + elif isinstance(layer, Basic.SqueezeLayer) or isinstance(layer, Basic.HaarDownsampling): + z, _ = layer(z, reverse=True) + elif isinstance(layer, Basic.Split): + if layer.level == 1: + a2, conditional_feature2 = self.level1_condFlow(None, z, eps_std=eps_std, reverse=True, training=training) + z = layer(z, a2, reverse=True) + elif layer.level == 0: + conditional_feature1 = torch.cat([z, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest')],1) + a1, _ = self.level0_condFlow(None, conditional_feature1, eps_std=eps_std, reverse=True, training=training) + z = layer(z, a1, reverse=True) + + + return z + diff --git a/codes/models/modules/FlowNet_SR_x4.py b/codes/models/modules/FlowNet_SR_x4.py new file mode 100644 index 0000000..17e6924 --- /dev/null +++ b/codes/models/modules/FlowNet_SR_x4.py @@ -0,0 +1,124 @@ +import numpy as np +import torch +from torch import nn as nn +import torch.nn.functional as F + +from utils.util import opt_get +from models.modules import Basic +from models.modules.FlowStep import 
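The rescaling FlowNet above is configured entirely through nested opt keys read via opt_get; a hypothetical options dict showing those keys (the values here are illustrative assumptions, not the released training settings):

opt = {
    'network_G': {
        'flowDownsampler': {
            'L': 2,                           # squeeze levels for x4 rescaling
            'K': [8, 8],                      # flow steps per level
            'squeeze': 'haar',                # 'checkerboard' (SqueezeLayer) or 'haar'
            'flow_permutation': 'invconv',    # or 'none'
            'flow_coupling': 'Affine3shift',  # or 'Affine', 'AffineInjector', 'noCoupling'
            'cond_channels': None,
            'nn_module': 'FCN',               # coupling network: 'FCN' or 'DenseBlock'
            'hidden_channels': 64,
            'splitOff': {
                'enable': True,
                'after_flowstep': [4, 4],     # conditional FlowSteps handled after each split
                'RRDB_nb': [5, 5], 'RRDB_nf': 64, 'RRDB_gc': 32,
                'flow_permutation': 'invconv', 'flow_coupling': 'Affine',
            },
        },
    },
}
# model = FlowNet(image_shape=(160, 160, 3), opt=opt)   # HR 160x160 -> 3-channel 40x40 LR-like latent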
FlowStep +from models.modules.ConditionalFlow import ConditionalFlow + +class FlowNet(nn.Module): + def __init__(self, image_shape, opt=None): + assert image_shape[2] == 1 or image_shape[2] == 3 + super().__init__() + H, W, self.C = image_shape + self.opt = opt + self.L = opt_get(opt, ['network_G', 'flowDownsampler', 'L']) + self.K = opt_get(opt, ['network_G', 'flowDownsampler', 'K']) + if isinstance(self.K, int): self.K = [self.K] * (self.L + 1) + + n_additionalFlowNoAffine = opt_get(self.opt, ['network_G', 'flowDownsampler', 'additionalFlowNoAffine'], 0) + flow_permutation = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_permutation'], 'invconv') + flow_coupling = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_coupling'], 'Affine') + cond_channels = opt_get(self.opt, ['network_G', 'flowDownsampler', 'cond_channels'], None) + enable_splitOff = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'enable'], False) + after_splitOff_flowStep = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'after_flowstep'], 0) + if isinstance(after_splitOff_flowStep, int): after_splitOff_flowStep = [after_splitOff_flowStep] * (self.L + 1) + + # construct flow + self.layers = nn.ModuleList() + self.output_shapes = [] + + for level in range(self.L): + # 1. Squeeze + self.layers.append(Basic.SqueezeLayer(factor=2)) # may need a better way for squeezing + self.C, H, W = self.C * 4, H // 2, W // 2 + self.output_shapes.append([-1, self.C, H, W]) + + # 2. main FlowSteps (unconditional flow) + for k in range(self.K[level]-after_splitOff_flowStep[level]): + self.layers.append(FlowStep(in_channels=self.C, cond_channels=cond_channels, + flow_permutation=flow_permutation, + flow_coupling=flow_coupling, + opt=opt['network_G']['flowDownsampler'])) + self.output_shapes.append([-1, self.C, H, W]) + + # 3. 
additional FlowSteps (split + conditional flow) + if enable_splitOff: + if level == 0: + self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) + self.level0_condFlow = ConditionalFlow(num_channels=self.C, + num_channels_split=self.C // 2 if level < self.L-1 else 3, + n_flow_step=after_splitOff_flowStep[level], + opt=opt['network_G']['flowDownsampler']['splitOff'], + num_levels_condition=1, SR=True) + elif level == 1: + self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) + self.level1_condFlow = ConditionalFlow(num_channels=self.C, + num_channels_split=self.C // 2 if level < self.L-1 else 3, + n_flow_step=after_splitOff_flowStep[level], + opt=opt['network_G']['flowDownsampler']['splitOff'], + num_levels_condition=0, SR=True) + self.C = self.C // 2 if level < self.L-1 else 3 + self.output_shapes.append([-1, self.C, H, W]) + + + self.H = H + self.W = W + self.scaleH = image_shape[0] / H + self.scaleW = image_shape[1] / W + print('shapes:', self.output_shapes) + + # nodetach version; 0.05 better than detach version, 0.30 better when using only nll loss + def forward(self, hr=None, z=None, u=None, eps_std=None, logdet=None, reverse=False, training=True): + if not reverse: + return self.normal_flow(hr, u=u, logdet=logdet, training=training) + else: + return self.reverse_flow(z, u=u, eps_std=eps_std, training=training) + + + ''' + hr->y1+z1->y2+z2 + ''' + def normal_flow(self, z, u=None, logdet=None, training=True): + for layer, shape in zip(self.layers, self.output_shapes): + if isinstance(layer, FlowStep): + z, logdet = layer(z, u, logdet=logdet, reverse=False) + elif isinstance(layer, Basic.SqueezeLayer): + z, logdet = layer(z, logdet=logdet, reverse=False) + elif isinstance(layer, Basic.Split): + if layer.level == 0: + z, a1 = layer(z, reverse=False) + y1 = z.clone() + elif layer.level == 1: + z, a2 = layer(z, reverse=False) + logdet, conditional_feature2 = self.level1_condFlow(a2, z, logdet=logdet, reverse=False, training=training) + + conditional_feature1 = torch.cat([y1, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest')],1) + logdet, _ = self.level0_condFlow(a1, conditional_feature1, logdet=logdet, reverse=False, training=training) + + return z, logdet + + ''' + y2+z2->y1+z1->hr + ''' + def reverse_flow(self, z, u=None, eps_std=None, training=True): + for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): + if isinstance(layer, FlowStep): + z, _ = layer(z, u, reverse=True) + elif isinstance(layer, Basic.SqueezeLayer): + z, _ = layer(z, reverse=True) + elif isinstance(layer, Basic.Split): + if layer.level == 1: + a2, _, conditional_feature2 = self.level1_condFlow(None, z, eps_std=eps_std, reverse=True, training=training) + z = layer(z, a2, reverse=True) + elif layer.level == 0: + conditional_feature1 = torch.cat([z, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest')],1) + a1, _, _ = self.level0_condFlow(None, conditional_feature1, eps_std=eps_std, reverse=True, training=training) + z = layer(z, a1, reverse=True) + + + + return z + diff --git a/codes/models/modules/FlowNet_SR_x8.py b/codes/models/modules/FlowNet_SR_x8.py new file mode 100644 index 0000000..b07800d --- /dev/null +++ b/codes/models/modules/FlowNet_SR_x8.py @@ -0,0 +1,145 @@ +import numpy as np +import torch +from torch import nn as nn +import torch.nn.functional as F + +from utils.util import opt_get +from models.modules import Basic +from models.modules.FlowStep import 
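For the x4 SR flow above, the constructor's self.C / H / W bookkeeping with L=2 works out as follows (a worked example assuming a 160x160 GT patch):

C, H, W, L = 3, 160, 160, 2
for level in range(L):
    C, H, W = C * 4, H // 2, W // 2          # squeeze
    print('after squeeze, level', level, (C, H, W))
    C = C // 2 if level < L - 1 else 3       # split keeps half the channels (3 at the last level)
    print('after split,   level', level, (C, H, W))
# (12, 80, 80) -> (6, 80, 80) -> (24, 40, 40) -> (3, 40, 40): the retained branch is the LR-sized latent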
FlowStep +from models.modules.ConditionalFlow import ConditionalFlow + +class FlowNet(nn.Module): + def __init__(self, image_shape, opt=None): + assert image_shape[2] == 1 or image_shape[2] == 3 + super().__init__() + H, W, self.C = image_shape + self.opt = opt + self.L = opt_get(opt, ['network_G', 'flowDownsampler', 'L']) + self.K = opt_get(opt, ['network_G', 'flowDownsampler', 'K']) + if isinstance(self.K, int): self.K = [self.K] * (self.L + 1) + + n_additionalFlowNoAffine = opt_get(self.opt, ['network_G', 'flowDownsampler', 'additionalFlowNoAffine'], 0) + flow_permutation = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_permutation'], 'invconv') + flow_coupling = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_coupling'], 'Affine') + cond_channels = opt_get(self.opt, ['network_G', 'flowDownsampler', 'cond_channels'], None) + enable_splitOff = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'enable'], False) + after_splitOff_flowStep = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'after_flowstep'], 0) + if isinstance(after_splitOff_flowStep, int): after_splitOff_flowStep = [after_splitOff_flowStep] * (self.L + 1) + + # construct flow + self.layers = nn.ModuleList() + self.output_shapes = [] + + for level in range(self.L): + # 1. Squeeze + self.layers.append(Basic.SqueezeLayer(factor=2)) # may need a better way for squeezing + self.C, H, W = self.C * 4, H // 2, W // 2 + self.output_shapes.append([-1, self.C, H, W]) + + # 2. main FlowSteps (unconditional flow) + for k in range(self.K[level]-after_splitOff_flowStep[level]): + self.layers.append(FlowStep(in_channels=self.C, cond_channels=cond_channels, + flow_permutation=flow_permutation, + flow_coupling=flow_coupling, + opt=opt['network_G']['flowDownsampler'])) + self.output_shapes.append([-1, self.C, H, W]) + + # 3. 
additional FlowSteps (split + conditional flow) + if enable_splitOff: + if level == 0: + self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) + self.level0_condFlow = ConditionalFlow(num_channels=self.C, + num_channels_split=self.C // 2 if level < self.L-1 else 3, + n_flow_step=after_splitOff_flowStep[level], + opt=opt['network_G']['flowDownsampler']['splitOff'], + num_levels_condition=2, SR=True) + elif level == 1: + self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) + self.level1_condFlow = ConditionalFlow(num_channels=self.C, + num_channels_split=self.C // 2 if level < self.L-1 else 3, + n_flow_step=after_splitOff_flowStep[level], + opt=opt['network_G']['flowDownsampler']['splitOff'], + num_levels_condition=1, SR=True) + elif level == 2: + self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) + self.level2_condFlow = ConditionalFlow(num_channels=self.C, + num_channels_split=self.C // 2 if level < self.L-1 else 3, + n_flow_step=after_splitOff_flowStep[level], + opt=opt['network_G']['flowDownsampler']['splitOff'], + num_levels_condition=0, SR=True) + + self.C = self.C // 2 if level < self.L-1 else 3 + self.output_shapes.append([-1, self.C, H, W]) + + self.H = H + self.W = W + self.scaleH = image_shape[0] / H + self.scaleW = image_shape[1] / W + print('shapes:', self.output_shapes) + + + # nodetach version: # nodetach version; 0.05 better than detach version, 0.30 better when using only nll loss + def forward(self, hr=None, z=None, u=None, eps_std=None, logdet=None, reverse=False, training=True): + if not reverse: + return self.normal_flow(hr, u=u, logdet=logdet, training=training) + else: + return self.reverse_flow(z, u=u, eps_std=eps_std, training=training) + + ''' + hr->y1+a1(z1)->y2+a2(z2)->y3+z3 + ''' + def normal_flow(self, z, u=None, logdet=None, training=True): + for layer, shape in zip(self.layers, self.output_shapes): + if isinstance(layer, FlowStep): + z, logdet = layer(z, u, logdet=logdet, reverse=False) + elif isinstance(layer, Basic.SqueezeLayer): + z, logdet = layer(z, logdet=logdet, reverse=False) + elif isinstance(layer, Basic.Split): + if layer.level == 0: + z, a1 = layer(z, reverse=False) + y1 = z.clone() + elif layer.level == 1: + z, a2 = layer(z, reverse=False) + y2 = z.clone() + elif layer.level == 2: + z, a3 = layer(z, reverse=False) + + logdet, conditional_feature3 = self.level2_condFlow(a3, z, logdet=logdet, reverse=False, training=training) + + conditional_feature2 = torch.cat([y2, F.interpolate(conditional_feature3, scale_factor=2, mode='nearest')],1) + logdet, conditional_feature2 = self.level1_condFlow(a2, conditional_feature2, logdet=logdet, reverse=False, training=training) + + conditional_feature1 = torch.cat([y1, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest'), + F.interpolate(conditional_feature3, scale_factor=4, mode='nearest')],1) + logdet, _ = self.level0_condFlow(a1, conditional_feature1, logdet=logdet, reverse=False, training=training) + + return z, logdet + + ''' + y3+z3(a3)->y2+z2(a2)->y1+z1(a1)->hr + ''' + def reverse_flow(self, z, u=None, eps_std=None, training=True): + for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): + if isinstance(layer, FlowStep): + z, _ = layer(z, u, reverse=True) + elif isinstance(layer, Basic.SqueezeLayer): + z, _ = layer(z, reverse=True) + elif isinstance(layer, Basic.Split): + if layer.level == 2: + a3, _, 
conditional_feature3 = self.level2_condFlow(None, z, eps_std=eps_std, reverse=True, training=training) + z = layer(z, a3, reverse=True) + elif layer.level == 1: + conditional_feature2 = torch.cat([z, F.interpolate(conditional_feature3, scale_factor=2, mode='nearest')],1) + a2, _, conditional_feature2 = self.level1_condFlow(None, conditional_feature2, eps_std=eps_std, reverse=True, training=training) + z = layer(z, a2, reverse=True) + elif layer.level == 0: + conditional_feature1 = torch.cat([z, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest'), + F.interpolate(conditional_feature3, scale_factor=4, mode='nearest')],1) + a1, _, _ = self.level0_condFlow(None, conditional_feature1, eps_std=eps_std, reverse=True, training=training) + z = layer(z, a1, reverse=True) + + + + + return z + diff --git a/codes/models/modules/FlowStep.py b/codes/models/modules/FlowStep.py new file mode 100644 index 0000000..a6bc07d --- /dev/null +++ b/codes/models/modules/FlowStep.py @@ -0,0 +1,65 @@ +import torch +from torch import nn as nn + +from utils.util import opt_get +from models.modules import ActNorms, Permutations, AffineCouplings + + +class FlowStep(nn.Module): + def __init__(self, in_channels, cond_channels=None, flow_permutation='invconv', flow_coupling='Affine', LRvsothers=True, + actnorm_scale=1.0, LU_decomposed=False, opt=None): + super().__init__() + self.flow_permutation = flow_permutation + self.flow_coupling = flow_coupling + + # 1. actnorm + self.actnorm = ActNorms.ActNorm2d(in_channels, actnorm_scale) + + # 2. permute # todo: maybe hurtful for downsampling; presever the structure of downsampling + if self.flow_permutation == "invconv": + self.permute = Permutations.InvertibleConv1x1(in_channels, LU_decomposed=LU_decomposed) + elif self.flow_permutation == "none": + self.permute = None + + # 3. coupling + if self.flow_coupling == "AffineInjector": + self.affine = AffineCouplings.AffineCouplingInjector(in_channels=in_channels, cond_channels=cond_channels, opt=opt) + elif self.flow_coupling == "noCoupling": + pass + elif self.flow_coupling == "Affine": + self.affine = AffineCouplings.AffineCoupling(in_channels=in_channels, cond_channels=cond_channels, opt=opt) + elif self.flow_coupling == "Affine3shift": + self.affine = AffineCouplings.AffineCoupling3shift(in_channels=in_channels, cond_channels=cond_channels, LRvsothers=LRvsothers, opt=opt) + + def forward(self, z, u=None, logdet=None, reverse=False): + if not reverse: + return self.normal_flow(z, u, logdet) + else: + return self.reverse_flow(z, u) + + def normal_flow(self, z, u=None, logdet=None): + # 1. actnorm + z, logdet = self.actnorm(z, logdet=logdet, reverse=False) + + # 2. permute + if self.permute is not None: + z, logdet = self.permute( z, logdet=logdet, reverse=False) + + # 3. coupling + z, logdet = self.affine(z, u=u, logdet=logdet, reverse=False) + + return z, logdet + + def reverse_flow(self, z, u=None, logdet=None): + # 1.coupling + z, _ = self.affine(z, u=u, reverse=True) + + # 2. permute + if self.permute is not None: + z, _ = self.permute(z, reverse=True) + + # 3. 
actnorm + z, _ = self.actnorm(z, reverse=True) + + return z, logdet + diff --git a/codes/models/modules/HCFlowNet_Rescaling_arch.py b/codes/models/modules/HCFlowNet_Rescaling_arch.py new file mode 100644 index 0000000..c4839c7 --- /dev/null +++ b/codes/models/modules/HCFlowNet_Rescaling_arch.py @@ -0,0 +1,60 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from utils.util import opt_get +from models.modules.FlowNet_Rescaling_x4 import FlowNet +from models.modules import Basic, thops + + + +class HCFlowNet_Rescaling(nn.Module): + def __init__(self, opt, step=None): + super(HCFlowNet_Rescaling, self).__init__() + self.opt = opt + self.quant = opt_get(opt, ['datasets', 'train', 'quant'], 256) + + hr_size = opt_get(opt, ['datasets', 'train', 'GT_size'], 160) + hr_channel = opt_get(opt, ['network_G', 'in_nc'], 3) + + # hr->lr+z + self.flow = FlowNet((hr_size, hr_size, hr_channel), opt=opt) + + # hr: HR image, lr: LR image, z: latent variable, u: conditional variable + def forward(self, hr=None, lr=None, z=None, u=None, eps_std=None, + add_gt_noise=False, step=None, reverse=False, training=True): + + # hr->z + if not reverse: + return self.normal_flow_diracLR(hr, lr, u, step=step, training=training) + # z->hr + else: # setting z to lr!!! + return self.reverse_flow_diracLR(lr, z, u, eps_std=eps_std, training=training) + + + #########################################diracLR + # hr->lr+z, diracLR + def normal_flow_diracLR(self, hr, lr, u=None, step=None, training=True): + # 1. quantitize HR + # hr = hr + (torch.rand(hr.shape, device=hr.device)) / self.quant # no quantization is better + + # 2. hr->lr+z + fake_lr_from_hr, fake_z1, fake_z2 = self.flow(hr=hr, u=u, logdet=None, reverse=False, training=training) + + return torch.clamp(fake_lr_from_hr, 0, 1), fake_z1, fake_z2 + + # lr+z->hr + def reverse_flow_diracLR(self, lr, z, u, eps_std, training=True): + + # lr+z->hr + fake_hr = self.flow(z=lr, u=u, eps_std=eps_std, reverse=True, training=training) + + return torch.clamp(fake_hr, 0, 1) + + + def get_score(self, disc_loss_sigma, z): + score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \ + z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma) + return -score_real \ No newline at end of file diff --git a/codes/models/modules/HCFlowNet_SR_arch.py b/codes/models/modules/HCFlowNet_SR_arch.py new file mode 100644 index 0000000..12214f7 --- /dev/null +++ b/codes/models/modules/HCFlowNet_SR_arch.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from utils.util import opt_get +from models.modules import Basic, thops + + + +class HCFlowNet_SR(nn.Module): + def __init__(self, opt, step=None): + super(HCFlowNet_SR, self).__init__() + self.opt = opt + self.quant = opt_get(opt, ['quant'], 256) + + hr_size = opt_get(opt, ['datasets', 'train', 'GT_size'], 160) + hr_channel = opt_get(opt, ['network_G', 'in_nc'], 3) + scale = opt_get(opt, ['scale']) + + if scale == 4: + from models.modules.FlowNet_SR_x4 import FlowNet + elif scale == 8: + from models.modules.FlowNet_SR_x8 import FlowNet + else: + raise NotImplementedError('Scale {} is not implemented'.format(scale)) + + # hr->lr+z + self.flow = FlowNet((hr_size, hr_size, hr_channel), opt=opt) + + self.quantization = Basic.Quantization() + + # hr: HR image, lr: LR image, z: latent variable, u: conditional variable + def forward(self, hr=None, lr=None, z=None, u=None, eps_std=None, + 
add_gt_noise=False, step=None, reverse=False, training=True): + + # hr->z + if not reverse: + return self.normal_flow_diracLR(hr, lr, u, step=step, training=training) + # z->hr + else: + return self.reverse_flow_diracLR(lr, z, u, eps_std=eps_std, training=training) + + + #########################################diracLR + # hr->lr+z, diracLR + def normal_flow_diracLR(self, hr, lr, u=None, step=None, training=True): + # 1. quantitize HR + pixels = thops.pixels(hr) + + # according to Glow and ours, it should be u~U(0,a) (0.06 better in practice), not u~U(-0.5,0.5) (though better in theory) + hr = hr + (torch.rand(hr.shape, device=hr.device)) / self.quant + logdet = torch.zeros_like(hr[:, 0, 0, 0]) + float(-np.log(self.quant) * pixels) + + # 2. hr->lr+z + fake_lr_from_hr, logdet = self.flow(hr=hr, u=u, logdet=logdet, reverse=False, training=training) + + # note in rescaling, we use LR for LR loss before quantization + fake_lr_from_hr = self.quantization(fake_lr_from_hr) + + # 3. loss, Gaussian with small variance to approximate Dirac delta function of LR. + # for the second term, using small log-variance may lead to svd problem, for both exp and tanh version + objective = logdet + Basic.GaussianDiag.logp(lr, -torch.ones_like(lr)*6, fake_lr_from_hr) + + nll = ((-objective) / float(np.log(2.) * pixels)).mean() + + return torch.clamp(fake_lr_from_hr, 0, 1), nll + + # lr+z->hr + def reverse_flow_diracLR(self, lr, z, u, eps_std, training=True): + + # lr+z->hr + fake_hr = self.flow(z=lr, u=u, eps_std=eps_std, reverse=True, training=training) + + return torch.clamp(fake_hr, 0, 1) diff --git a/codes/models/modules/Permutations.py b/codes/models/modules/Permutations.py new file mode 100644 index 0000000..a58773c --- /dev/null +++ b/codes/models/modules/Permutations.py @@ -0,0 +1,108 @@ +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F +import scipy.linalg + +from models.modules import thops + + +class Permute2d(nn.Module): + def __init__(self, num_channels, shuffle): + super().__init__() + self.num_channels = num_channels + self.indices = np.arange(self.num_channels - 1, -1, -1).astype(np.long) + self.indices_inverse = np.zeros((self.num_channels), dtype=np.long) + for i in range(self.num_channels): + self.indices_inverse[self.indices[i]] = i + if shuffle: + self.reset_indices() + + def reset_indices(self): + np.random.shuffle(self.indices) + for i in range(self.num_channels): + self.indices_inverse[self.indices[i]] = i + + def forward(self, input, logdet=None, reverse=False): + if not reverse: + return input[:, self.indices, :, :], logdet + else: + return input[:, self.indices_inverse, :, :], logdet + + +class InvertibleConv1x1(nn.Module): + def __init__(self, num_channels, LU_decomposed=False): + super().__init__() + w_shape = [num_channels, num_channels] + w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32) + if not LU_decomposed: + # Sample a random orthogonal matrix: + self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) + else: + # W = PL(U+diag(s)) + np_p, np_l, np_u = scipy.linalg.lu(w_init) + np_s = np.diag(np_u) + np_sign_s = np.sign(np_s) + np_log_s = np.log(np.abs(np_s)) + np_u = np.triu(np_u, k=1) + l_mask = np.tril(np.ones(w_shape, dtype=np.float32), -1) + eye = np.eye(*w_shape, dtype=np.float32) + + self.register_buffer('p', torch.Tensor(np_p.astype(np.float32))) # remains fixed + self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(np.float32))) # the sign is fixed + self.l = 
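The nll returned by normal_flow_diracLR above is the usual change-of-variables objective converted to bits per dimension: uniform dequantization contributes -log(quant) per pixel to logdet, and the total is divided by log(2) * pixels. A toy illustration of just that conversion (numbers are made up):

import numpy as np

pixels = 3 * 160 * 160                       # C * H * W of the HR patch
quant = 256
logdet = float(-np.log(quant) * pixels)      # dequantization term, before the flow's own Jacobian terms
log_p = -1.5e5                               # toy log-likelihood of the transformed variables
bits_per_dim = -(logdet + log_p) / (np.log(2.) * pixels)
print(round(bits_per_dim, 2))                # the per-pixel-channel loss value, in bits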
nn.Parameter(torch.Tensor(np_l.astype(np.float32))) # optimized except diagonal 1 + self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(np.float32))) + self.u = nn.Parameter(torch.Tensor(np_u.astype(np.float32))) # optimized + self.l_mask = torch.Tensor(l_mask) + self.eye = torch.Tensor(eye) + self.w_shape = w_shape + self.LU = LU_decomposed + + def get_weight(self, input, reverse): + # The difference in computational cost will become significant for large c, although for the networks in + # our experiments we did not measure a large difference in wallclock computation time. + if not self.LU: + if not reverse: + # pixels = thops.pixels(input) + # GPU version + # dlogdet = torch.slogdet(self.weight)[1] * pixels + # CPU version is 2x faster, https://github.com/didriknielsen/survae_flows/issues/5. + dlogdet = (torch.slogdet(self.weight.to('cpu'))[1] * thops.pixels(input)).to(self.weight.device) + weight = self.weight.view(self.w_shape[0], self.w_shape[1], 1, 1) + else: + dlogdet = 0 + weight = torch.inverse(self.weight.double()).float().view(self.w_shape[0], self.w_shape[1], 1, 1) + + + return weight, dlogdet + else: + self.p = self.p.to(input.device) + self.sign_s = self.sign_s.to(input.device) + self.l_mask = self.l_mask.to(input.device) + self.eye = self.eye.to(input.device) + l = self.l * self.l_mask + self.eye + u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s)) + dlogdet = thops.sum(self.log_s) * thops.pixels(input) + if not reverse: + w = torch.matmul(self.p, torch.matmul(l, u)) + else: + l = torch.inverse(l.double()).float() + u = torch.inverse(u.double()).float() + w = torch.matmul(u, torch.matmul(l, self.p.inverse())) + return w.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet + + def forward(self, input, logdet=None, reverse=False): + """ + log-det = log|abs(|W|)| * pixels + """ + weight, dlogdet = self.get_weight(input, reverse) + if not reverse: + z = F.conv2d(input, weight) # fc layer, ie, permute channel + if logdet is not None: + logdet = logdet + dlogdet + return z, logdet + else: + z = F.conv2d(input, weight) + if logdet is not None: + logdet = logdet - dlogdet + return z, logdet diff --git a/codes/models/modules/discriminator_vgg_arch.py b/codes/models/modules/discriminator_vgg_arch.py new file mode 100644 index 0000000..d029266 --- /dev/null +++ b/codes/models/modules/discriminator_vgg_arch.py @@ -0,0 +1,189 @@ +import torch +import torch.nn as nn +import torchvision + + +class Discriminator_VGG_128(nn.Module): + def __init__(self, in_nc, nf): + super(Discriminator_VGG_128, self).__init__() + # [64, 128, 128] + self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) + self.bn0_1 = nn.BatchNorm2d(nf, affine=True) + # [64, 64, 64] + self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) + self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) + self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) + self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) + # [128, 32, 32] + self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) + self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) + self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) + self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) + # [256, 16, 16] + self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) + # [512, 8, 8] + 
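+        # Shape note (descriptive comment): the five stride-2 convolutions halve the
+        # spatial size of a 128x128 input as 128 -> 64 -> 32 -> 16 -> 8 -> 4, so the
+        # flattened feature fed to linear1 below has nf*8 * 4 * 4 elements; the
+        # hard-coded 512 * 4 * 4 in linear1 therefore assumes nf = 64.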
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) + + self.linear1 = nn.Linear(512 * 4 * 4, 100) + self.linear2 = nn.Linear(100, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.lrelu(self.conv0_0(x)) + fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) + + fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) + fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) + + fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) + fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) + + fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) + fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) + + fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) + fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) + + fea = fea.view(fea.size(0), -1) + fea = self.lrelu(self.linear1(fea)) + out = self.linear2(fea) + return out + + def reset_parameters(self): + for layer in self.children(): + if hasattr(layer, 'reset_parameters'): + layer.reset_parameters() + + + +class Discriminator_VGG_160(nn.Module): + def __init__(self, in_nc, nf): + super(Discriminator_VGG_160, self).__init__() + # [64, 160, 160] + self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) + self.bn0_1 = nn.BatchNorm2d(nf, affine=True) + # [64, 80, 80] + self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) + self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) + self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) + self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) + # [128, 40, 40] + self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) + self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) + self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) + self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) + # [256, 20, 20] + self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) + # [512, 10, 10] + self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) + # [512, 5, 5] + + self.linear1 = nn.Linear(512 * 5 * 5, 100) + self.linear2 = nn.Linear(100, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.lrelu(self.conv0_0(x)) + fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) + + fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) + fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) + + fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) + fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) + + fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) + fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) + + fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) + fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) + + fea = fea.view(fea.size(0), -1) + fea = self.lrelu(self.linear1(fea)) + out = self.linear2(fea) + return out + + def reset_parameters(self): + for layer in self.children(): + if hasattr(layer, 'reset_parameters'): + layer.reset_parameters() + + +class VGGFeatureExtractor(nn.Module): + def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, + 
device=torch.device('cpu')): + super(VGGFeatureExtractor, self).__init__() + self.use_input_norm = use_input_norm + if use_bn: + model = torchvision.models.vgg19_bn(pretrained=True) + else: + model = torchvision.models.vgg19(pretrained=True) + if self.use_input_norm: + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] + self.register_buffer('mean', mean) + self.register_buffer('std', std) + self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) + # No need to BP to variable + for k, v in self.features.named_parameters(): + v.requires_grad = False + + def forward(self, x): + # Assume input range is [0, 1] + if self.use_input_norm: + x = (x - self.mean) / self.std + output = self.features(x) + return output + + +class PatchGANDiscriminator(nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, in_nc=3, ndf=64, n_layers=35, norm_layer=nn.BatchNorm2d): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(PatchGANDiscriminator, self).__init__() + use_bias = False + kw = 3 + padw = 0 + sequence = [nn.Conv2d(in_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True)] + + for i in range(0, n_layers): + sequence += [ + nn.Conv2d(ndf, ndf, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf, 1, kernel_size=kw, stride=1, padding=padw, bias=use_bias)] # output 1 channel prediction map + # TODO + self.model = nn.Sequential(*sequence) + + def forward(self, x): + """Standard forward.""" + return self.model(x) diff --git a/codes/models/modules/loss.py b/codes/models/modules/loss.py new file mode 100644 index 0000000..41b528b --- /dev/null +++ b/codes/models/modules/loss.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + + +class CharbonnierLoss(nn.Module): + """Charbonnier Loss (L1)""" + + def __init__(self, eps=1e-6): + super(CharbonnierLoss, self).__init__() + self.eps = eps + + def forward(self, x, y): + diff = x - y + loss = torch.sum(torch.sqrt(diff * diff + self.eps)) + return loss + + +# Define GAN loss: [gan(vanilla) | lsgan | wgan-gp | ragan] +class GANLoss(nn.Module): + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type.lower() + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'gan' or self.gan_type == 'ragan': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan-gp': + + def wgan_loss(input, target): + # target is boolean + return -1 * input.mean() if target else input.mean() + + self.loss = wgan_loss + else: + raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) + + def get_target_label(self, input, target_is_real): + if self.gan_type == 'wgan-gp': + return target_is_real + if target_is_real: + return torch.empty_like(input).fill_(self.real_label_val) + else: + return torch.empty_like(input).fill_(self.fake_label_val) + + def forward(self, input, 
target_is_real): + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + return loss + + +class GradientPenaltyLoss(nn.Module): + def __init__(self, device=torch.device('cpu')): + super(GradientPenaltyLoss, self).__init__() + self.register_buffer('grad_outputs', torch.Tensor()) + self.grad_outputs = self.grad_outputs.to(device) + + def get_grad_outputs(self, input): + if self.grad_outputs.size() != input.size(): + self.grad_outputs.resize_(input.size()).fill_(1.0) + return self.grad_outputs + + def forward(self, interp, interp_crit): + grad_outputs = self.get_grad_outputs(interp_crit) + grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, + grad_outputs=grad_outputs, create_graph=True, + retain_graph=True, only_inputs=True)[0] + grad_interp = grad_interp.view(grad_interp.size(0), -1) + grad_interp_norm = grad_interp.norm(2, dim=1) + + loss = ((grad_interp_norm - 1)**2).mean() + return loss + +class ReconstructionLoss(nn.Module): + def __init__(self, losstype='l2', eps=1e-6): + super(ReconstructionLoss, self).__init__() + self.losstype = losstype + self.eps = eps + + def forward(self, x, target): + if self.losstype == 'l2': + return torch.mean(torch.sum((x - target)**2, (1, 2, 3))) + elif self.losstype == 'l1': + diff = x - target + return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3))) + else: + print("reconstruction loss type error!") + return 0 + diff --git a/codes/models/modules/module_util.py b/codes/models/modules/module_util.py new file mode 100644 index 0000000..166355f --- /dev/null +++ b/codes/models/modules/module_util.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + +def initialize_weights_xavier(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight) + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.xavier_normal_(m.weight) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualBlock_noBN(nn.Module): + '''Residual block w/o BN + ---Conv-ReLU-Conv-+- + |________________| + ''' + + def __init__(self, nf=64): + super(ResidualBlock_noBN, self).__init__() + self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + # initialization + initialize_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = F.relu(self.conv1(x), inplace=True) + out = self.conv2(out) + return identity + out + + +def flow_warp(x, 
flow, interp_mode='bilinear', padding_mode='zeros'): + """Warp an image or feature map with optical flow + Args: + x (Tensor): size (N, C, H, W) + flow (Tensor): size (N, H, W, 2), normal value + interp_mode (str): 'nearest' or 'bilinear' + padding_mode (str): 'zeros' or 'border' or 'reflection' + + Returns: + Tensor: warped image or feature map + """ + assert x.size()[-2:] == flow.size()[1:3] + B, C, H, W = x.size() + # mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + grid = grid.type_as(x) + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) + return output diff --git a/codes/models/modules/thops.py b/codes/models/modules/thops.py new file mode 100644 index 0000000..777d40a --- /dev/null +++ b/codes/models/modules/thops.py @@ -0,0 +1,53 @@ +import torch + + +def sum(tensor, dim=None, keepdim=False): + if dim is None: + # sum up all dim + return torch.sum(tensor) + else: + if isinstance(dim, int): + dim = [dim] + dim = sorted(dim) + for d in dim: + tensor = tensor.sum(dim=d, keepdim=True) + if not keepdim: + for i, d in enumerate(dim): + tensor.squeeze_(d-i) + return tensor + + +def mean(tensor, dim=None, keepdim=False): + if dim is None: + # mean all dim + return torch.mean(tensor) + else: + if isinstance(dim, int): + dim = [dim] + dim = sorted(dim) + for d in dim: + tensor = tensor.mean(dim=d, keepdim=True) + if not keepdim: + for i, d in enumerate(dim): + tensor.squeeze_(d-i) + return tensor + + + +def split_feature(tensor, type="split"): + """ + type = ["split", "cross"] + """ + C = tensor.size(1) + if type == "split": + return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...] + elif type == "cross": + return tensor[:, 0::2, ...], tensor[:, 1::2, ...] + + +def cat_feature(tensor_a, tensor_b): + return torch.cat((tensor_a, tensor_b), dim=1) + + +def pixels(tensor): + return int(tensor.size(2) * tensor.size(3)) \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py new file mode 100644 index 0000000..e605266 --- /dev/null +++ b/codes/models/networks.py @@ -0,0 +1,71 @@ +import importlib +import logging +import torch +import models.modules.discriminator_vgg_arch as SRGAN_arch + +logger = logging.getLogger('base') + + +def find_model_using_name(model_name): + model_filename = "models.modules." + model_name + "_arch" + modellib = importlib.import_module(model_filename) + + model = None + target_model_name = model_name.replace('_Net', '') + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower(): + model = cls + + if model is None: + print( + "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." 
% ( + model_filename, target_model_name)) + exit(0) + + return model + +def define_Flow(opt, step): + opt_net = opt['network_G'] + which_model = opt_net['which_model_G'] + + Arch = find_model_using_name(which_model) + netG = Arch(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], K=opt_net['flow']['K'], opt=opt, step=step) + return netG + +def define_G(opt, step): + which_model = opt['network_G']['which_model_G'] + + Arch = find_model_using_name(which_model) + netG = Arch(opt=opt, step=step) + return netG + +#### Discriminator +def define_D(opt): + opt_net = opt['network_D'] + which_model = opt_net['which_model_D'] + + if which_model == 'discriminator_vgg_128': + netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf']) + elif which_model == 'discriminator_vgg_160': + netD = SRGAN_arch.Discriminator_VGG_160(in_nc=opt_net['in_nc'], nf=opt_net['nf']) + elif which_model == 'PatchGANDiscriminator': + netD = SRGAN_arch.PatchGANDiscriminator(in_nc=opt_net['in_nc'], ndf=opt_net['ndf'], n_layers=opt_net['n_layers'],) + else: + raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) + return netD + + +#### Define Network used for Perceptual Loss +def define_F(opt, use_bn=False): + gpu_ids = opt['gpu_ids'] + device = torch.device('cuda' if gpu_ids else 'cpu') + # PyTorch pretrained VGG19-54, before ReLU. + if use_bn: + feature_layer = 49 + else: + feature_layer = 34 + netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, + use_input_norm=True, device=device) + netF.eval() # No need to train + return netF diff --git a/codes/options/options.py b/codes/options/options.py new file mode 100644 index 0000000..bc2b07f --- /dev/null +++ b/codes/options/options.py @@ -0,0 +1,138 @@ +import os +import os.path as osp +import logging +import yaml +from utils.util import OrderedYaml + +Loader, Dumper = OrderedYaml() + + +def parse(opt_path, gpu_ids=None, is_train=True): + with open(opt_path, mode='r') as f: + opt = yaml.load(f, Loader=Loader) + # export CUDA_VISIBLE_DEVICES + if gpu_ids is not None: opt['gpu_ids'] = [int(x) for x in gpu_ids.split(',')] + gpu_list = ','.join(str(x) for x in opt['gpu_ids']) + os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list + print('exporting CUDA_VISIBLE_DEVICES=' + gpu_list) + + opt['is_train'] = is_train + if opt['distortion'] == 'sr': + scale = opt['scale'] + + if 'datasets' in opt: + # datasets + for phase, dataset in opt['datasets'].items(): + phase = phase.split('_')[0] + print(dataset) + dataset['phase'] = phase + if opt['distortion'] == 'sr': + dataset['scale'] = scale + is_lmdb = False + if dataset.get('dataroot_GT', None) is not None: + dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) + if dataset['dataroot_GT'].endswith('lmdb'): + is_lmdb = True + # if dataset.get('dataroot_GT_bg', None) is not None: + # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg']) + if dataset.get('dataroot_LQ', None) is not None: + dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) + if dataset['dataroot_LQ'].endswith('lmdb'): + is_lmdb = True + dataset['data_type'] = 'lmdb' if is_lmdb else 'img' + if dataset['mode'].endswith('mc'): # for memcached + dataset['data_type'] = 'mc' + dataset['mode'] = dataset['mode'].replace('_mc', '') + + # path + for key, path in opt['path'].items(): + if path and key in opt['path'] and key != 'strict_load': + opt['path'][key] = osp.expanduser(path) + opt['path']['root'] = 
osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) + + if is_train: + experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_state'] = osp.join(experiments_root, 'training_state') + opt['path']['log'] = experiments_root + opt['path']['val_images'] = osp.join(experiments_root, 'val_images') + + # change some options for debug mode + if 'debug' in opt['name']: + opt['train']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(opt['path']['root'], 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + + + # network + if opt['distortion'] == 'sr': + opt['network_G']['scale'] = scale + + # relative learning rate + if 'train' in opt: + niter = opt['train']['niter'] + if 'T_period_rel' in opt['train']: + opt['train']['T_period'] = [int(x * niter) for x in opt['train']['T_period_rel']] + if 'restarts_rel' in opt['train']: + opt['train']['restarts'] = [int(x * niter) for x in opt['train']['restarts_rel']] + if 'lr_steps_rel' in opt['train']: + opt['train']['lr_steps'] = [int(x * niter) for x in opt['train']['lr_steps_rel']] + if 'lr_steps_inverse_rel' in opt['train']: + opt['train']['lr_steps_inverse'] = [int(x * niter) for x in opt['train']['lr_steps_inverse_rel']] + print(opt['train']) + + + return opt + + +def dict2str(opt, indent_l=1): + '''dict to string for logger''' + msg = '' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_l * 2) + k + ':[\n' + msg += dict2str(v, indent_l + 1) + msg += ' ' * (indent_l * 2) + ']\n' + else: + msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' + return msg + + +class NoneDict(dict): + def __missing__(self, key): + return None + + +# convert to NoneDict, which return None for missing key. 
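+# A minimal usage sketch (hypothetical values): after dict_to_nonedict(), missing
+# option keys read as None instead of raising KeyError, which is what the test code
+# relies on when probing optional settings such as opt['crop_border'], e.g.
+#   opt = dict_to_nonedict({'train': {'lr_G': 2.5e-4}})
+#   opt['train']['lr_G']   # 2.5e-4
+#   opt['train']['lr_D']   # None, not a KeyError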
+def dict_to_nonedict(opt): + if isinstance(opt, dict): + new_opt = dict() + for key, sub_opt in opt.items(): + new_opt[key] = dict_to_nonedict(sub_opt) + return NoneDict(**new_opt) + elif isinstance(opt, list): + return [dict_to_nonedict(sub_opt) for sub_opt in opt] + else: + return opt + + +def check_resume(opt, resume_iter): + '''Check resume states and pretrain_model paths (overriding pretrain_paths)''' + logger = logging.getLogger('base') + if opt['path']['resume_state']: + if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( + 'pretrain_model_D', None) is not None: + logger.warning('pretrain_model path will be ignored when resuming training.') + + opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], '{}_G.pth'.format(resume_iter)) + logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) + + if opt['train']['gan_weight'] > 0: + opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], + '{}_D.pth'.format(resume_iter)) + logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) diff --git a/codes/options/test/test_Rescaling_DF2K_4X_HCFlow.yml b/codes/options/test/test_Rescaling_DF2K_4X_HCFlow.yml new file mode 100644 index 0000000..178a05c --- /dev/null +++ b/codes/options/test/test_Rescaling_DF2K_4X_HCFlow.yml @@ -0,0 +1,93 @@ +#### general settings +name: 003_HCFlow_DF2K_x4_rescaling_test +suffix: ~ +use_tb_logger: true +model: HCFlow_Rescaling +distortion: sr +scale: 4 +gpu_ids: [0] + + +datasets: + test0: + name: example + mode: GTLQ + dataroot_GT: ../datasets/example_general_4X/HR + dataroot_LQ: ../datasets/example_general_4X/LR + +# test_1: +# name: Set5 +# mode: GTLQx +# dataroot_GT: ../datasets/Set5/HR +# dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 +# +# test_2: +# name: Set14 +# mode: GTLQx +# dataroot_GT: ../datasets/Set14/HR +# dataroot_LQ: ../datasets/Set14/LR_bicubic/X4 +# +# test_3: +# name: BSD100 +# mode: GTLQx +# dataroot_GT: ../datasets/BSD100/HR +# dataroot_LQ: ../datasets/BSD100/LR_bicubic/X4 +# +# test_4: +# name: Urban100 +# mode: GTLQx +# dataroot_GT: ../datasets/Urban100/HR +# dataroot_LQ: ../datasets/Urban100/LR_bicubic/X4 +# +# test_5: +# name: DIV2K-validation +# mode: GTLQx +# dataroot_GT: ../datasets/DIV2K/HR +# dataroot_LQ: ../datasets/DIV2K/LR_bicubic/X4 + + +#### network structures +network_G: + which_model_G: HCFlowNet_Rescaling + in_nc: 3 + out_nc: 3 + act_norm_start_step: 100 + + flowDownsampler: + K: 14 + L: 2 + squeeze: haar # better than squeeze2d + flow_permutation: none # bettter than invconv + flow_coupling: Affine3shift # better than affine + nn_module: DenseBlock # better than FCN + hidden_channels: 32 + cond_channels: ~ + splitOff: + enable: true + after_flowstep: [6, 6] + flow_permutation: invconv + flow_coupling: Affine + stage1: True + feature_extractor: RRDB + nn_module: FCN + nn_module_last: Conv2dZeros + hidden_channels: 64 + RRDB_nb: [2,1] + RRDB_nf: 64 + RRDB_gc: 16 + + + +#### validation settings +val: + heats: [1.0] + n_sample: 1 + + +path: + strict_load: true + load_submodule: ~ + pretrain_model_G: ../experiments/pretrained_models/Rescaling_DF2K_X4_HCFlow.pth + + + diff --git a/codes/options/test/test_SR_CelebA_8X_HCFlow.yml b/codes/options/test/test_SR_CelebA_8X_HCFlow.yml new file mode 100644 index 0000000..2daf537 --- /dev/null +++ b/codes/options/test/test_SR_CelebA_8X_HCFlow.yml @@ -0,0 +1,80 @@ +#### general settings +name: 002_HCFlow_CelebA_x8_bicSR_test +suffix: ~ +use_tb_logger: true +model: HCFlow_SR +distortion: sr +scale: 8 +quant: 256 
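+# Note: quant is the number of quantization levels assumed by HCFlowNet_SR; during the
+# normal (HR->LR+z) pass the HR image receives uniform dequantization noise of magnitude
+# 1/quant and the log-determinant is initialised with -log(quant) per pixel.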
+gpu_ids: [0] + + +#### datasets +datasets: + test0: + name: example + mode: GTLQ + dataroot_GT: ../datasets/example_face_4X/HR + dataroot_LQ: ../datasets/example_face_4X/LR + +# val: +# name: SR_CelebA_8X_160_val +# mode: LRHR_PKL +# dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va.pklv4 +# dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va_X8.pklv4 +# n_max: 20 +# +# test: +# name: SR_CelebA_8X_160_test +# mode: LRHR_PKL +# dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_te.pklv4 +# dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_te_X8.pklv4 +# n_max: 5000 + + + + +#### network structures +network_G: + which_model_G: HCFlowNet_SR + in_nc: 3 + out_nc: 3 + act_norm_start_step: 100 + + flowDownsampler: + K: 26 + L: 3 + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + hidden_channels: 64 + cond_channels: ~ + splitOff: + enable: true + after_flowstep: [13, 13, 13] + flow_permutation: invconv + flow_coupling: Affine + stage1: True + nn_module: FCN + nn_module_last: Conv2dZeros + hidden_channels: 64 + RRDB_nb: [5, 5] + RRDB_nf: 64 + RRDB_gc: 32 + + + +#### validation settings +val: + heats: [0, 0.8] + n_sample: 1 + + +path: + strict_load: true + load_submodule: ~ +# pretrain_model_G: ../experiments/pretrained_models/SR_CelebA_X8_HCFlow.pth +# pretrain_model_G: ../experiments/pretrained_models/SR_CelebA_X8_HCFlow+.pth + pretrain_model_G: ../experiments/pretrained_models/SR_CelebA_X8_HCFlow++.pth + + diff --git a/codes/options/test/test_SR_DF2K_4X_HCFlow.yml b/codes/options/test/test_SR_DF2K_4X_HCFlow.yml new file mode 100644 index 0000000..c67e5d2 --- /dev/null +++ b/codes/options/test/test_SR_DF2K_4X_HCFlow.yml @@ -0,0 +1,93 @@ +#### general settings +name: 001_HCFlow_DF2K_x4_bicSR_test +suffix: ~ +use_tb_logger: true +model: HCFlow_SR +distortion: sr +scale: 4 +quant: 64 +gpu_ids: [0] + + + +datasets: + test0: + name: example + mode: GTLQ + dataroot_GT: ../datasets/example_general_4X/HR + dataroot_LQ: ../datasets/example_general_4X/LR + +# test_1: +# name: Set5 +# mode: GTLQx +# dataroot_GT: ../datasets/Set5/HR +# dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 + +# test_2: +# name: Set14 +# mode: GTLQ +# dataroot_GT: ../datasets/Set14/HR +# dataroot_LQ: ../datasets/Set14/LR_bicubic/X4 +# +# test_3: +# name: BSD100 +# mode: GTLQ +# dataroot_GT: ../datasets/BSD100/HR +# dataroot_LQ: ../datasets/BSD100/LR_bicubic/X4 +# +# test_4: +# name: Urban100 +# mode: GTLQ +# dataroot_GT: ../datasets/Urban100/HR +# dataroot_LQ: ../datasets/Urban100/LR_bicubic/X4 +# +# test_5: +# name: DIV2K-va-4X +# mode: GTLQ +# dataroot_GT: ../datasets/srflow_datasets/div2k-validation-modcrop8-gt +# dataroot_LQ: ../datasets/srflow_datasets/div2k-validation-modcrop8-x4 + + +#### network structures +network_G: + which_model_G: HCFlowNet_SR + in_nc: 3 + out_nc: 3 + act_norm_start_step: 100 + + flowDownsampler: + K: 26 + L: 2 + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + hidden_channels: 64 + cond_channels: ~ + splitOff: + enable: true + after_flowstep: [13, 13] + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + nn_module_last: Conv2dZeros + hidden_channels: 64 + RRDB_nb: [7, 7] + RRDB_nf: 64 + RRDB_gc: 32 + + +#### validation settings +val: + heats: [0,0, 0.9] + n_sample: 1 + + +path: + strict_load: true + load_submodule: ~ +# pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow.pth +# pretrain_model_G: 
../experiments/pretrained_models/SR_DF2K_X4_HCFlow+.pth + pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow++.pth + + + diff --git a/codes/options/train/train_Rescaling_DF2K_4X_HCFlow.yml b/codes/options/train/train_Rescaling_DF2K_4X_HCFlow.yml new file mode 100644 index 0000000..c1bbb92 --- /dev/null +++ b/codes/options/train/train_Rescaling_DF2K_4X_HCFlow.yml @@ -0,0 +1,127 @@ +#### general settings +name: 003_DF2K_x4_rescaling_HCFlow +use_tb_logger: true +model: HCFlow_Rescaling +distortion: sr +scale: 4 +gpu_ids: [0] + + +#### datasets +datasets: + train: + name: DF2K_tr + mode: GTLQnpy + dataroot_GT: /cluster/work/cvl/jinliang_dataset/DIV2K+Flickr2K_decoded/DIV2K+Flickr2K_HR + dataroot_LQ: /cluster/work/cvl/jinliang_dataset/DIV2K+Flickr2K_decoded/DIV2K+Flickr2K_LR_bicubic/X4 + + use_shuffle: true + n_workers: 16 + batch_size: 16 + GT_size: 160 + use_flip: true + use_rot: true + color: RGB + + val: + name: Set5 + mode: GTLQx + dataroot_GT: ../datasets/Set5/HR + dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 + + +# The optimization may not be stable for rescaling (+-0.1dB). A simple trick: for each stage of learning rate, +# resume training from the best model of the previous stage of learning rate. +#### network structures +network_G: + which_model_G: HCFlowNet_Rescaling + in_nc: 3 + out_nc: 3 + act_norm_start_step: 100 + + flowDownsampler: + K: 14 + L: 2 + squeeze: haar # better than squeeze2d + flow_permutation: none # bettter than invconv + flow_coupling: Affine3shift # better than affine + nn_module: DenseBlock # better than FCN + hidden_channels: 32 + cond_channels: ~ + splitOff: + enable: true + after_flowstep: [6, 6] + flow_permutation: invconv + flow_coupling: Affine + stage1: True + feature_extractor: RRDB + nn_module: FCN + nn_module_last: Conv2dZeros + hidden_channels: 64 + RRDB_nb: [2,1] + RRDB_nf: 64 + RRDB_gc: 16 + + +#### path +path: + pretrain_model_G: ~ + strict_load: true + resume_state: auto + +#### training settings: learning rate scheme, loss +train: + two_stage_opt: True + + lr_G: !!float 2.5e-4 + lr_scheme: MultiStepLR + weight_decay_G: 0 + max_grad_clip: 5 + max_grad_norm: 100 + beta1: 0.9 + beta2: 0.99 + niter: 500000 + warmup_iter: -1 # no warm up + lr_steps: [100000, 200000, 300000, 400000, 450000] + lr_gamma: 0.5 + restarts: ~ + restart_weights: ~ + eta_min: !!float 1e-8 + + weight_z: !!float 1e-5 + + pixel_criterion_lr: l2 + pixel_weight_lr: !!float 5e-2 + + eps_std_reverse: 1.0 + pixel_criterion_hr: l1 + pixel_weight_hr: 1.0 + + # perceptual loss + feature_criterion: l1 + feature_weight: 0 + + # gan loss + gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) + gan_weight: 0 + + lr_D: 0 + beta1_D: 0.9 + beta2_D: 0.99 + D_update_ratio: 1 + D_init_iters: 1500 + + manual_seed: 0 + val_freq: !!float 5e3 + + +#### validation settings +val: + heats: [0.0, 1.0] + + +#### logger +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 + diff --git a/codes/options/train/train_SR_CelebA_8X_HCFlow++.yml b/codes/options/train/train_SR_CelebA_8X_HCFlow++.yml new file mode 100644 index 0000000..8a48a96 --- /dev/null +++ b/codes/options/train/train_SR_CelebA_8X_HCFlow++.yml @@ -0,0 +1,126 @@ +#### general settings +name: 002_CelebA_x8_bicSR_HCFlow++ +use_tb_logger: true +model: HCFlow_SR +distortion: sr +scale: 8 +quant: 256 +gpu_ids: [0] + + +#### datasets +datasets: + train: + name: CelebA_160_tr + mode: LRHR_PKL + dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr.pklv4 + dataroot_LQ: 
/cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr_X8.pklv4 + + use_shuffle: true + n_workers: 16 + batch_size: 16 + GT_size: 160 + use_flip: true + color: RGB + val: + name: CelebA_160_va + mode: LRHR_PKL + dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va.pklv4 + dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va_X8.pklv4 + n_max: 20 + + +#### network structures +network_G: + which_model_G: HCFlowNet_SR + in_nc: 3 + out_nc: 3 + act_norm_start_step: 100 + + flowDownsampler: + K: 26 + L: 3 + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + hidden_channels: 64 # 64 and 128 are similar + cond_channels: ~ # affine coupling in the main trunk, testo 3 or None + splitOff: + enable: true + after_flowstep: [13, 13, 13] + flow_permutation: invconv + flow_coupling: Affine + stage1: True + nn_module: FCN + nn_module_last: Conv2dZeros + hidden_channels: 64 + RRDB_nb: [5, 5] + RRDB_nf: 64 + RRDB_gc: 32 + +network_D: + which_model_D: discriminator_vgg_160 + in_nc: 3 + nf: 64 + + +#### path +path: + pretrain_model_G: ../experiments/pretrained_models/SR_CelebA_X8_HCFlow.pth + strict_load: true + resume_state: auto + + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 1.25e-5 + lr_scheme: MultiStepLR + weight_decay_G: 0 + max_grad_clip: 5 + max_grad_norm: 100 + beta1: 0.9 + beta2: 0.99 + niter: 50000 + warmup_iter: -1 # no warm up + lr_steps: [20000, 40000] + lr_gamma: 0.5 + restarts: ~ + restart_weights: ~ + eta_min: !!float 1e-8 + + nll_weight: !!float 2e-3 + + # pixel loss + pixel_criterion_hr: l1 + pixel_weight_hr: 1.0 + + # perceptual loss + eps_std_reverse: 0.8 + feature_criterion: l1 + feature_weight: !!float 5e-2 # balance diversity and lpips + + # gan loss + gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) + gan_weight: !!float 5e-1 + + lr_D: !!float 5e-5 + beta1_D: 0.9 + beta2_D: 0.99 + D_update_ratio: 1 + D_init_iters: 1500 + + manual_seed: 0 + val_freq: !!float 5e3 + + +#### validation settings +val: + heats: [0.0, 0.8] # 0.8 has best visual quality for face SR + n_sample: 3 + + +#### logger +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 + diff --git a/codes/options/train/train_SR_CelebA_8X_HCFlow+.yml b/codes/options/train/train_SR_CelebA_8X_HCFlow+.yml new file mode 100644 index 0000000..2953518 --- /dev/null +++ b/codes/options/train/train_SR_CelebA_8X_HCFlow+.yml @@ -0,0 +1,126 @@ +#### general settings +name: 002_CelebA_x8_bicSR_HCFlow+ +use_tb_logger: true +model: HCFlow_SR +distortion: sr +scale: 8 +quant: 256 +gpu_ids: [0] + + +#### datasets +datasets: + train: + name: CelebA_160_tr + mode: LRHR_PKL + dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr.pklv4 + dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr_X8.pklv4 + + use_shuffle: true + n_workers: 16 + batch_size: 16 + GT_size: 160 + use_flip: true + color: RGB + val: + name: CelebA_160_va + mode: LRHR_PKL + dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va.pklv4 + dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va_X8.pklv4 + n_max: 20 + + +#### network structures +network_G: + which_model_G: HCFlowNet_SR + in_nc: 3 + out_nc: 3 + act_norm_start_step: 100 + + flowDownsampler: + K: 26 + L: 3 + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + hidden_channels: 64 # 64 and 128 are similar + cond_channels: ~ # affine coupling in the main trunk, testo 3 or None + 
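+    # Note: with K: 26 and L: 3 above, after_flowstep: [13, 13, 13] enables the split-off
+    # branch after the 13th of the 26 flow steps at each of the three levels; the keys
+    # below configure that branch's own flow_permutation/flow_coupling and its RRDB
+    # settings (RRDB_nb/RRDB_nf/RRDB_gc).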
splitOff: + enable: true + after_flowstep: [13, 13, 13] + flow_permutation: invconv + flow_coupling: Affine + stage1: True + nn_module: FCN + nn_module_last: Conv2dZeros + hidden_channels: 64 + RRDB_nb: [5, 5] + RRDB_nf: 64 + RRDB_gc: 32 + +network_D: + which_model_D: discriminator_vgg_160 + in_nc: 3 + nf: 64 + + +#### path +path: + pretrain_model_G: ../experiments/pretrained_models/SR_CelebA_X8_HCFlow.pth + strict_load: true + resume_state: auto + + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 5e-5 + lr_scheme: MultiStepLR + weight_decay_G: 0 + max_grad_clip: 5 + max_grad_norm: 100 + beta1: 0.9 + beta2: 0.99 + niter: 50000 + warmup_iter: -1 # no warm up + lr_steps: [20000, 40000] + lr_gamma: 0.5 + restarts: ~ + restart_weights: ~ + eta_min: !!float 1e-8 + + nll_weight: !!float 2e-3 + + # pixel loss + pixel_criterion_hr: l1 + pixel_weight_hr: 1.0 + + # perceptual loss + eps_std_reverse: 0.8 + feature_criterion: l1 + feature_weight: 0 + + # gan loss + gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) + gan_weight: 0 + + lr_D: 0 + beta1_D: 0.9 + beta2_D: 0.99 + D_update_ratio: 1 + D_init_iters: 1500 + + manual_seed: 0 + val_freq: !!float 5e3 + + +#### validation settings +val: + heats: [0.0, 0.8] # 0.8 has best visual quality for face SR + n_sample: 3 + + +#### logger +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 + diff --git a/codes/options/train/train_SR_CelebA_8X_HCFlow.yml b/codes/options/train/train_SR_CelebA_8X_HCFlow.yml new file mode 100644 index 0000000..db4ca58 --- /dev/null +++ b/codes/options/train/train_SR_CelebA_8X_HCFlow.yml @@ -0,0 +1,125 @@ +#### general settings +name: 002_CelebA_x8_bicSR_HCFlow +use_tb_logger: true +model: HCFlow_SR +distortion: sr +scale: 8 +quant: 256 +gpu_ids: [0] + + +#### datasets +datasets: + train: + name: CelebA_160_tr + mode: LRHR_PKL + dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr.pklv4 + dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr_X8.pklv4 + + use_shuffle: true + n_workers: 16 + batch_size: 16 + GT_size: 160 + use_flip: true + color: RGB + val: + name: CelebA_160_va + mode: LRHR_PKL + dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va.pklv4 + dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va_X8.pklv4 + n_max: 20 + + +#### network structures +network_G: + which_model_G: HCFlowNet_SR + in_nc: 3 + out_nc: 3 + act_norm_start_step: 100 + + flowDownsampler: + K: 26 + L: 3 + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + hidden_channels: 64 + cond_channels: ~ + splitOff: + enable: true + after_flowstep: [13, 13, 13] + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + nn_module_last: Conv2dZeros + hidden_channels: 64 + RRDB_nb: [5, 5] + RRDB_nf: 64 + RRDB_gc: 32 + +network_D: + which_model_D: discriminator_vgg_160 + in_nc: 3 + nf: 64 + + +#### path +path: + pretrain_model_G: ~ + strict_load: true + resume_state: auto + + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 2.5e-4 + lr_scheme: MultiStepLR + weight_decay_G: 0 + max_grad_clip: 5 + max_grad_norm: 100 + beta1: 0.9 + beta2: 0.99 + niter: 350000 + warmup_iter: -1 # no warm up + lr_steps: [200000, 250000, 280000, 310000, 340000] + lr_gamma: 0.5 + restarts: ~ + restart_weights: ~ + eta_min: !!float 1e-8 + + nll_weight: 1 + + # pixel loss + pixel_criterion_hr: l1 + pixel_weight_hr: 0 + + # perceptual loss + eps_std_reverse: 0.8 + feature_criterion: l1 
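+  # Note: in this base HCFlow config the pixel (pixel_weight_hr), perceptual
+  # (feature_weight) and GAN (gan_weight) terms are all 0, so training is driven by the
+  # NLL term alone (nll_weight: 1); the HCFlow+ / HCFlow++ configs start from this
+  # pretrained model and enable the pixel and, for ++, perceptual/GAN losses.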
+ feature_weight: 0 + + # gan loss + gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) + gan_weight: 0 + + lr_D: 0 + beta1_D: 0.9 + beta2_D: 0.99 + D_update_ratio: 1 + D_init_iters: 1500 + + manual_seed: 0 + val_freq: !!float 5e3 + + +#### validation settings +val: + heats: [0.0, 0.8] + n_sample: 3 + + +#### logger +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 + diff --git a/codes/options/train/train_SR_DF2K_4X_HCFlow++.yml b/codes/options/train/train_SR_DF2K_4X_HCFlow++.yml new file mode 100644 index 0000000..32b6f80 --- /dev/null +++ b/codes/options/train/train_SR_DF2K_4X_HCFlow++.yml @@ -0,0 +1,124 @@ +#### general settings +name: 001_DF2K_x4_bicSR_HCFlow++ +use_tb_logger: true +model: HCFlow_SR +distortion: sr +scale: 4 +quant: 64 +gpu_ids: [0] + + +#### datasets +datasets: + train: + name: DF2K_tr + mode: LRHR_PKL + dataroot_GT: ../datasets/srflow_datasets/DF2K-tr.pklv4 + dataroot_LQ: ../datasets/srflow_datasets/DF2K-tr_X4.pklv4 + + use_shuffle: true + n_workers: 16 + batch_size: 16 + GT_size: 160 + use_flip: true + color: RGB + val: + name: Set5 + mode: GTLQx + dataroot_GT: ../datasets/Set5/HR + dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 + + +#### network structures +network_G: + which_model_G: HCFlowNet_SR + in_nc: 3 + out_nc: 3 + act_norm_start_step: 100 + + flowDownsampler: + K: 26 + L: 2 + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + hidden_channels: 64 + cond_channels: ~ + splitOff: + enable: true + after_flowstep: [13, 13] + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + nn_module_last: Conv2dZeros + hidden_channels: 64 + RRDB_nb: [7, 7] + RRDB_nf: 64 + RRDB_gc: 32 + +network_D: + which_model_D: discriminator_vgg_160 + in_nc: 3 + nf: 64 + + +#### path +path: + pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow.pth + strict_load: true + resume_state: auto + + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 1.25e-5 + lr_scheme: MultiStepLR + weight_decay_G: 0 + max_grad_clip: 5 + max_grad_norm: 100 + beta1: 0.9 + beta2: 0.99 + niter: 50000 + warmup_iter: -1 # no warm up + lr_steps: [20000, 40000] + lr_gamma: 0.5 + restarts: ~ + restart_weights: ~ + eta_min: !!float 1e-8 + + nll_weight: !!float 2e-3 + + # pixel loss + pixel_criterion_hr: l1 + pixel_weight_hr: 1.0 + + # perceptual loss + eps_std_reverse: 0.9 + feature_criterion: l1 + feature_weight: !!float 5e-2 # balance diversity and lpips + + # gan loss + gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) + gan_weight: !!float 5e-1 + + lr_D: !!float 5e-5 + beta1_D: 0.9 + beta2_D: 0.99 + D_update_ratio: 1 + D_init_iters: 1500 + + manual_seed: 0 + val_freq: !!float 5e3 + + +#### validation settings +val: + heats: [0.0, 0.9] # 0.9 has best visual quality for general SR + n_sample: 3 + + +#### logger +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 + diff --git a/codes/options/train/train_SR_DF2K_4X_HCFlow+.yml b/codes/options/train/train_SR_DF2K_4X_HCFlow+.yml new file mode 100644 index 0000000..4e918b6 --- /dev/null +++ b/codes/options/train/train_SR_DF2K_4X_HCFlow+.yml @@ -0,0 +1,124 @@ +#### general settings +name: 001_DF2K_x4_bicSR_HCFlow+ +use_tb_logger: true +model: HCFlow_SR +distortion: sr +scale: 4 +quant: 64 +gpu_ids: [0] + + +#### datasets +datasets: + train: + name: DF2K_tr + mode: LRHR_PKL + dataroot_GT: ../datasets/srflow_datasets/DF2K-tr.pklv4 + dataroot_LQ: ../datasets/srflow_datasets/DF2K-tr_X4.pklv4 + + use_shuffle: true + n_workers: 16 + batch_size: 16 + 
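+    # Note: GT_size below is the HR training crop size (160x160); at scale: 4 the
+    # corresponding LR crop is 40x40.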
GT_size: 160 + use_flip: true + color: RGB + val: + name: Set5 + mode: GTLQx + dataroot_GT: ../datasets/Set5/HR + dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 + + +#### network structures +network_G: + which_model_G: HCFlowNet_SR + in_nc: 3 + out_nc: 3 + act_norm_start_step: 100 + + flowDownsampler: + K: 26 + L: 2 + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + hidden_channels: 64 + cond_channels: ~ + splitOff: + enable: true + after_flowstep: [13, 13] + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + nn_module_last: Conv2dZeros + hidden_channels: 64 + RRDB_nb: [7, 7] + RRDB_nf: 64 + RRDB_gc: 32 + +network_D: + which_model_D: discriminator_vgg_160 + in_nc: 3 + nf: 64 + + +#### path +path: + pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow.pth + strict_load: true + resume_state: auto + + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 5e-5 + lr_scheme: MultiStepLR + weight_decay_G: 0 + max_grad_clip: 5 + max_grad_norm: 100 + beta1: 0.9 + beta2: 0.99 + niter: 50000 + warmup_iter: -1 # no warm up + lr_steps: [20000, 40000] + lr_gamma: 0.5 + restarts: ~ + restart_weights: ~ + eta_min: !!float 1e-8 + + nll_weight: !!float 2e-3 + + # pixel loss + pixel_criterion_hr: l1 + pixel_weight_hr: 1.0 + + # perceptual loss + eps_std_reverse: 0.9 + feature_criterion: l1 + feature_weight: 0 + + # gan loss + gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) + gan_weight: 0 + + lr_D: 0 + beta1_D: 0.9 + beta2_D: 0.99 + D_update_ratio: 1 + D_init_iters: 1500 + + manual_seed: 0 + val_freq: !!float 5e3 + + +#### validation settings +val: + heats: [0.0, 0.9] # 0.9 has best visual quality for general SR + n_sample: 3 + + +#### logger +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 + diff --git a/codes/options/train/train_SR_DF2K_4X_HCFlow.yml b/codes/options/train/train_SR_DF2K_4X_HCFlow.yml new file mode 100644 index 0000000..c202e93 --- /dev/null +++ b/codes/options/train/train_SR_DF2K_4X_HCFlow.yml @@ -0,0 +1,124 @@ +#### general settings +name: 001_DF2K_x4_bicSR_HCFlow +use_tb_logger: true +model: HCFlow_SR +distortion: sr +scale: 4 +quant: 64 +gpu_ids: [0] + + +#### datasets +datasets: + train: + name: DF2K_tr + mode: LRHR_PKL + dataroot_GT: ../datasets/srflow_datasets/DF2K-tr.pklv4 + dataroot_LQ: ../datasets/srflow_datasets/DF2K-tr_X4.pklv4 + + use_shuffle: true + n_workers: 16 + batch_size: 16 + GT_size: 160 + use_flip: true + color: RGB + val: + name: Set5 + mode: GTLQx + dataroot_GT: ../datasets/Set5/HR + dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 + + +#### network structures +network_G: + which_model_G: HCFlowNet_SR + in_nc: 3 + out_nc: 3 + act_norm_start_step: 100 + + flowDownsampler: + K: 26 + L: 2 + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + hidden_channels: 64 + cond_channels: ~ + splitOff: + enable: true + after_flowstep: [13, 13] + flow_permutation: invconv + flow_coupling: Affine + nn_module: FCN + nn_module_last: Conv2dZeros + hidden_channels: 64 + RRDB_nb: [7, 7] + RRDB_nf: 64 + RRDB_gc: 32 + +network_D: + which_model_D: discriminator_vgg_160 + in_nc: 3 + nf: 64 + + +#### path +path: + pretrain_model_G: ~ + strict_load: true + resume_state: auto + + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 2.5e-4 + lr_scheme: MultiStepLR + weight_decay_G: 0 + max_grad_clip: 5 + max_grad_norm: 100 + beta1: 0.9 + beta2: 0.99 + niter: 300000 + warmup_iter: -1 # no warm up + lr_steps_rel: [0.5,0.75,0.9,0.95] + lr_gamma: 0.5 + restarts: 
~ + restart_weights: ~ + eta_min: !!float 1e-8 + + nll_weight: 1 + + # pixel loss + pixel_criterion_hr: l1 + pixel_weight_hr: 0 + + # perceptual loss + eps_std_reverse: 0.9 + feature_criterion: l1 + feature_weight: 0 + + # gan loss + gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) + gan_weight: 0 + + lr_D: 0 + beta1_D: 0.9 + beta2_D: 0.99 + D_update_ratio: 1 + D_init_iters: 1500 + + manual_seed: 0 + val_freq: !!float 5e3 + + +#### validation settings +val: + heats: [0.0, 0.9] + n_sample: 3 + + +#### logger +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 + diff --git a/codes/scripts/png2npy.py b/codes/scripts/png2npy.py new file mode 100644 index 0000000..93b6024 --- /dev/null +++ b/codes/scripts/png2npy.py @@ -0,0 +1,42 @@ +import os +import argparse +import skimage.io as sio +import numpy as np + +# usage: python scripts/png2npy.py --pathFrom ../datasets/DIV2K+Flickr2K/DIV2K+Flickr2K_HR/ --pathTo ../datasets/DIV2K+Flickr2K_decoded/DIV2K+Flickr2K_HR/ +# python scripts/png2npy.py --pathFrom ../datasets/DIV2K+Flickr2K/DIV2K+Flickr2K_LR_bicubic/ --pathTo ../datasets/DIV2K+Flickr2K_decoded/DIV2K+Flickr2K_LR_bicubic/ + + +parser = argparse.ArgumentParser(description='Pre-processing .png images') +parser.add_argument('--pathFrom', default='', + help='directory of images to convert') +parser.add_argument('--pathTo', default='', + help='directory of images to save') +parser.add_argument('--split', default=True, + help='save individual images') +parser.add_argument('--select', default='', + help='select certain path') + +args = parser.parse_args() + +for (path, dirs, files) in os.walk(args.pathFrom): + print(path) + targetDir = os.path.join(args.pathTo, path[len(args.pathFrom) + 1:]) + if len(args.select) > 0 and path.find(args.select) == -1: + continue + + if not os.path.exists(targetDir): + os.mkdir(targetDir) + + if len(dirs) == 0: + pack = {} + n = 0 + for fileName in files: + (idx, ext) = os.path.splitext(fileName) + if ext == '.png': + image = sio.imread(os.path.join(path, fileName)) + if args.split: + np.save(os.path.join(targetDir, idx + '.npy'), image) + n += 1 + if n % 100 == 0: + print('Converted ' + str(n) + ' images.') diff --git a/codes/scripts/prepare_data_pkl.py b/codes/scripts/prepare_data_pkl.py new file mode 100644 index 0000000..b056452 --- /dev/null +++ b/codes/scripts/prepare_data_pkl.py @@ -0,0 +1,119 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd. +# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode +# +# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import sys + +import numpy as np +import random +import imageio + +from natsort import natsort +from tqdm import tqdm + +def get_img_paths(dir_path, wildcard='*.png'): + return natsort.natsorted(glob.glob(dir_path + '/' + wildcard)) + +def create_all_dirs(path): + if "." 
in path.split("/")[-1]: + dirs = os.path.dirname(path) + else: + dirs = path + os.makedirs(dirs, exist_ok=True) + +def to_pklv4(obj, path, vebose=False): + create_all_dirs(path) + with open(path, 'wb') as f: + pickle.dump(obj, f, protocol=4) + if vebose: + print("Wrote {}".format(path)) + + +from imresize import imresize + +def random_crop(img, size): + h, w, c = img.shape + + h_start = np.random.randint(0, h - size) + h_end = h_start + size + + w_start = np.random.randint(0, w - size) + w_end = w_start + size + + return img[h_start:h_end, w_start:w_end] + + +def imread(img_path): + img = imageio.imread(img_path) + if len(img.shape) == 2: + img = np.stack([img, ] * 3, axis=2) + return img + + +def to_pklv4_1pct(obj, path, vebose): + n = int(round(len(obj) * 0.01)) + path = path.replace(".", "_1pct.") + to_pklv4(obj[:n], path, vebose=True) + + +def main(dir_path): + hrs = [] + lqs = [] + + img_paths = get_img_paths(dir_path) + for img_path in tqdm(img_paths): + img = imread(img_path) + + for i in range(47): + crop = random_crop(img, 160) + cropX4 = imresize(crop, scalar_scale=0.25) + hrs.append(crop) + lqs.append(cropX4) + + shuffle_combined(hrs, lqs) + + hrs_path = get_hrs_path(dir_path) + to_pklv4(hrs, hrs_path, vebose=True) + to_pklv4_1pct(hrs, hrs_path, vebose=True) + + lqs_path = get_lqs_path(dir_path) + to_pklv4(lqs, lqs_path, vebose=True) + to_pklv4_1pct(lqs, lqs_path, vebose=True) + + +def get_hrs_path(dir_path): + base_dir = os.path.dirname(dir_path) + name = os.path.basename(dir_path) + hrs_path = os.path.join(base_dir, 'pkls', name + '.pklv4') + return hrs_path + + +def get_lqs_path(dir_path): + base_dir = os.path.dirname(dir_path) + name = os.path.basename(dir_path) + hrs_path = os.path.join(base_dir, 'pkls', name + '_X4.pklv4') + return hrs_path + + +def shuffle_combined(hrs, lqs): + combined = list(zip(hrs, lqs)) + random.shuffle(combined) + hrs[:], lqs[:] = zip(*combined) + + +if __name__ == "__main__": + dir_path = sys.argv[1] + assert os.path.isdir(dir_path) + main(dir_path) diff --git a/codes/test_HCFlow.py b/codes/test_HCFlow.py new file mode 100644 index 0000000..69b4696 --- /dev/null +++ b/codes/test_HCFlow.py @@ -0,0 +1,252 @@ +import os.path +import logging +import time +import argparse +from collections import OrderedDict +import numpy as np +import torch +import options.options as option +import utils.util as util +from utils.imresize import imresize +from data.util import bgr2ycbcr +from data import create_dataset, create_dataloader +from models import create_model +import lpips + +os.environ['CUDA_HOME'] = '/scratch_net/rind/cuda-11.0' +os.environ['LD_LIBRARY_PATH'] = '/scratch_net/rind/cuda-11.0/lib64' +os.environ['PATH'] = '/scratch_net/rind/cuda-11.0/bin' + +#### options +parser = argparse.ArgumentParser() # test_SR_CelebA_8X_HCFlow test_SR_DF2K_4X_HCFlow test_Rescaling_DF2K_4X_HCFlow +parser.add_argument('--opt', type=str, default='options/test/test_SR_CelebA_8X_HCFlow.yml', help='Path to options YMAL file.') +parser.add_argument('--save_kernel', action='store_true', default=False, help='Save Kernel Esimtation.') +args = parser.parse_args() +opt = option.parse(args.opt, is_train=False) +opt = option.dict_to_nonedict(opt) +device_id = torch.cuda.current_device() + +#### mkdir and logger +util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' + and 'pretrain_model' not in key and 'resume' not in key and 'load_submodule' not in key)) +util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, + 
screen=True, tofile=True) +logger = logging.getLogger('base') +logger.info(option.dict2str(opt)) + +# set random seed +util.set_random_seed(0) + +#### Create test dataset and dataloader +test_loaders = [] +for phase, dataset_opt in sorted(opt['datasets'].items()): + test_set = create_dataset(dataset_opt) + test_loader = create_dataloader(test_set, dataset_opt) + logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) + test_loaders.append(test_loader) + +# load pretrained model by default +model = create_model(opt) +loss_fn_alex = lpips.LPIPS(net='alex').to('cuda') +crop_border = opt['crop_border'] if opt['crop_border'] else opt['scale'] + +for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info('\n\nTesting [{:s}]...'.format(test_set_name)) + test_start_time = time.time() + dataset_dir = os.path.join(opt['path']['results_root'], test_set_name) + util.mkdir(dataset_dir) + + idx = 0 + psnr_dict={} # for HR image + ssim_dict={} + psnr_y_dict = {} + ssim_y_dict = {} + bic_hr_psnr_dict={} # for bic(HR) + bic_hr_ssim_dict={} + bic_hr_psnr_y_dict = {} + bic_hr_ssim_y_dict = {} + lpips_dict = {} + diversity_dict = {} # pixel-wise variance + avg_lr_psnr = 0.0 # for generated LR image + avg_lr_ssim = 0.0 + avg_lr_psnr_y = 0.0 + avg_lr_ssim_y = 0.0 + avg_nll = 0.0 + + for test_data in test_loader: + idx += 1 + + real_image = True if test_loader.dataset.opt['dataroot_GT'] is None else False + generate_online = True if test_loader.dataset.opt['dataroot_GT'] is not None and test_loader.dataset.opt[ + 'dataroot_LQ'] is None else False + img_path = test_data['LQ_path'][0] if real_image else test_data['GT_path'][0] + img_name = os.path.splitext(os.path.basename(img_path))[0] + + model.feed_data(test_data) + nll = model.test() + avg_nll += nll + visuals = model.get_current_visuals() + + # calculate PSNR for LR + gt_img_lr = util.tensor2img(visuals['LQ']) + sr_img_lr = util.tensor2img(visuals['LQ_fromH']) + # save_img_path = os.path.join(dataset_dir, 'LR_{:s}_{:.1f}_{:d}.png'.format(img_name, 1.0, 0)) + # util.save_img(sr_img_lr, save_img_path) + gt_img_lr = gt_img_lr / 255. + sr_img_lr = sr_img_lr / 255. + + lr_psnr, lr_ssim, lr_psnr_y, lr_ssim_y = util.calculate_psnr_ssim(gt_img_lr, sr_img_lr, 0) + avg_lr_psnr += lr_psnr + avg_lr_ssim += lr_ssim + avg_lr_psnr_y += lr_psnr_y + avg_lr_ssim_y += lr_ssim_y + + # deal with real-world data (just save) + if real_image: + for heat in opt['val']['heats']: + for sample in range(opt['val']['n_sample']): + sr_img = util.tensor2img(visuals['SR', heat, sample]) + + # deal with the image margins for real images + if opt['scale'] == 4: + real_crop = 3 + elif opt['scale'] == 2: + real_crop = 6 + elif opt['scale'] == 1: + real_crop = 11 + assert real_crop * opt['scale'] * 2 > opt['kernel_size'] + sr_img = sr_img[real_crop * opt['scale']:-real_crop * opt['scale'], + real_crop * opt['scale']:-real_crop * opt['scale'], :] + save_img_path = os.path.join(dataset_dir, 'SR_{:s}_{:.1f}_{:d}.png'. 
+ format(img_name, heat, sample)) + util.save_img(sr_img, save_img_path) + + # deal with synthetic data (calculate psnr and save) + else: + for heat in opt['val']['heats']: + psnr = 0.0 + ssim = 0.0 + psnr_y = 0.0 + ssim_y = 0.0 + lpips_value = 0.0 + bic_hr_psnr = 0.0 + bic_hr_ssim = 0.0 + bic_hr_psnr_y = 0.0 + bic_hr_ssim_y = 0.0 + + sr_img_list =[] + for sample in range(opt['val']['n_sample']): + gt_img = visuals['GT'] + sr_img = visuals['SR', heat, sample] + sr_img_list.append(sr_img.unsqueeze(0)*255) + lpips_dict[(idx, heat, sample)] = float(loss_fn_alex(2 * gt_img.to('cuda') - 1, 2 * sr_img.to('cuda') - 1).cpu()) + lpips_value += lpips_dict[(idx, heat, sample)] + + gt_img = util.tensor2img(gt_img) # uint8 + sr_img = util.tensor2img(sr_img) # uint8 + suffix = opt['suffix'] + if suffix: + save_img_path = os.path.join(dataset_dir, 'SR_{:s}_{:.1f}_{:d}_{:s}.png'.format(img_name, heat, sample, suffix)) + else: + save_img_path = os.path.join(dataset_dir, 'SR_{:s}_{:.1f}_{:d}.png'.format(img_name, heat, sample)) + util.save_img(sr_img, save_img_path) + + gt_img = gt_img / 255. + sr_img = sr_img / 255. + bic_hr_gt_img = imresize(gt_img, 1 / opt['scale']) + bic_hr_sr_img = imresize(sr_img, 1 / opt['scale']) + + psnr_dict[(idx, heat, sample)], ssim_dict[(idx, heat, sample)], \ + psnr_y_dict[(idx, heat, sample)], ssim_y_dict[(idx, heat, sample)] = util.calculate_psnr_ssim(gt_img, sr_img, crop_border) + psnr += psnr_dict[(idx, heat, sample)] + ssim += ssim_dict[(idx, heat, sample)] + psnr_y += psnr_y_dict[(idx, heat, sample)] + ssim_y += ssim_y_dict[(idx, heat, sample)] + bic_hr_psnr_dict[(idx, heat, sample)], bic_hr_ssim_dict[(idx, heat, sample)], \ + bic_hr_psnr_y_dict[(idx, heat, sample)], bic_hr_ssim_y_dict[(idx, heat, sample)] = util.calculate_psnr_ssim(bic_hr_gt_img, bic_hr_sr_img, 0) + bic_hr_psnr += bic_hr_psnr_dict[(idx, heat, sample)] + bic_hr_ssim += bic_hr_ssim_dict[(idx, heat, sample)] + bic_hr_psnr_y += bic_hr_psnr_y_dict[(idx, heat, sample)] + bic_hr_ssim_y += bic_hr_ssim_y_dict[(idx, heat, sample)] + + + # mean pixel-wise variance + psnr /= opt['val']['n_sample'] + ssim /= opt['val']['n_sample'] + psnr_y /= opt['val']['n_sample'] + ssim_y /= opt['val']['n_sample'] + diversity_dict[(idx, heat)] = float(torch.cat(sr_img_list, 0).std([0]).mean().cpu()) + lpips_value /= opt['val']['n_sample'] + bic_hr_psnr /= opt['val']['n_sample'] + bic_hr_ssim /= opt['val']['n_sample'] + bic_hr_psnr_y /= opt['val']['n_sample'] + bic_hr_ssim_y /= opt['val']['n_sample'] + + + logger.info('{:20s} ({}samples),heat:{:.1f}) ' + 'HR:PSNR/SSIM/PSNR_Y/SSIM_Y/LPIPS/Diversity: {:.2f}/{:.4f}/{:.2f}/{:.4f}/{:.4f}/{:.4f}, ' + 'bicHR:PSNR/SSIM/PSNR_Y/SSIM_Y: {:.2f}/{:.4f}/{:.2f}/{:.4f}, ' + 'LR:PSNR/SSIM/PSNR_Y/SSIM_Y: {:.2f}/{:.4f}/{:.2f}/{:.4f}, NLL: {:.4f}'.format( + img_name, opt['val']['n_sample'], heat, + psnr, ssim, psnr_y, ssim_y, lpips_value, diversity_dict[(idx, heat)], + bic_hr_psnr, bic_hr_ssim, bic_hr_psnr_y, bic_hr_ssim_y, + lr_psnr, lr_ssim, lr_psnr_y, lr_ssim_y, nll)) + + # Average PSNR/SSIM results + avg_lr_psnr /= idx + avg_lr_ssim /= idx + avg_lr_psnr_y /= idx + avg_lr_ssim_y /= idx + avg_nll = avg_nll / idx + + if real_image: + logger.info('----{} ({} images), avg LR PSNR/SSIM/PSNR_K/LR_SSIM_Y: {:.2f}/{:.4f}/{:.2f}/{:.4f}\n'.format(test_set_name, idx, avg_lr_psnr, avg_lr_ssim, avg_lr_psnr_y, avg_lr_ssim_y)) + else: + logger.info('-------------------------------------------------------------------------------------') + for heat in opt['val']['heats']: + avg_psnr = 0.0 + avg_ssim = 0.0 + 
avg_psnr_y = 0.0 + avg_ssim_y = 0.0 + avg_lpips = 0.0 + avg_diversity = 0.0 + avg_bic_hr_psnr = 0.0 + avg_bic_hr_ssim = 0.0 + avg_bic_hr_psnr_y = 0.0 + avg_bic_hr_ssim_y = 0.0 + + for iidx in range(1, idx+1): + for sample in range(opt['val']['n_sample']): + avg_psnr += psnr_dict[(iidx, heat, sample)] + avg_ssim += ssim_dict[(iidx, heat, sample)] + avg_psnr_y += psnr_y_dict[(iidx, heat, sample)] + avg_ssim_y += ssim_y_dict[(iidx, heat, sample)] + avg_lpips += lpips_dict[(iidx, heat, sample)] + avg_bic_hr_psnr += bic_hr_psnr_dict[(iidx, heat, sample)] + avg_bic_hr_ssim += bic_hr_ssim_dict[(iidx, heat, sample)] + avg_bic_hr_psnr_y += bic_hr_psnr_y_dict[(iidx, heat, sample)] + avg_bic_hr_ssim_y += bic_hr_ssim_y_dict[(iidx, heat, sample)] + avg_diversity += diversity_dict[(iidx, heat)] + + avg_psnr = avg_psnr / idx / opt['val']['n_sample'] + avg_ssim = avg_ssim / idx / opt['val']['n_sample'] + avg_psnr_y = avg_psnr_y / idx / opt['val']['n_sample'] + avg_ssim_y = avg_ssim_y / idx / opt['val']['n_sample'] + avg_lpips = avg_lpips / idx / opt['val']['n_sample'] + avg_diversity = avg_diversity / idx + avg_bic_hr_psnr = avg_bic_hr_psnr / idx / opt['val']['n_sample'] + avg_bic_hr_ssim = avg_bic_hr_ssim / idx / opt['val']['n_sample'] + avg_bic_hr_psnr_y = avg_bic_hr_psnr_y / idx / opt['val']['n_sample'] + avg_bic_hr_ssim_y = avg_bic_hr_ssim_y / idx / opt['val']['n_sample'] + + # log + logger.info(opt['path']['pretrain_model_G']) + logger.info('----{} ({}images,{}samples,heat:{:.1f}) ' + 'average HR:PSNR/SSIM/PSNR_Y/SSIM_Y/LPIPS/Diversity: {:.2f}/{:.4f}/{:.2f}/{:.4f}/{:.4f}/{:.4f}, ' + 'bicHR:PSNR/SSIM/PSNR_Y/SSIM_Y: {:.2f}/{:.4f}/{:.2f}/{:.4f}, ' + 'LR:PSNR/SSIM/PSNR_Y/SSIM_Y: {:.2f}/{:.4f}/{:.2f}/{:.4f}, NLL: {:.4f}'.format( + test_set_name, idx, opt['val']['n_sample'], heat, + avg_psnr, avg_ssim, avg_psnr_y, avg_ssim_y, avg_lpips, avg_diversity, + avg_bic_hr_psnr, avg_bic_hr_ssim, avg_bic_hr_psnr_y, avg_bic_hr_ssim_y, + avg_lr_psnr, avg_lr_ssim, avg_lr_psnr_y, avg_lr_ssim_y, avg_nll)) diff --git a/codes/train_HCFlow.py b/codes/train_HCFlow.py new file mode 100644 index 0000000..b64c9db --- /dev/null +++ b/codes/train_HCFlow.py @@ -0,0 +1,314 @@ +import os +import math +import argparse +import random +import logging +import numpy as np +import torch +from data.data_sampler import DistIterSampler, EnlargedSampler +from data.util import bgr2ycbcr + +import options.options as option +from utils import util +from data import create_dataloader, create_dataset +from models import create_model +from utils.dist_util import get_dist_info, init_dist + +import socket +import getpass +import lpips + + + +def main(): + #### setup options + parser = argparse.ArgumentParser() # train_SR_CelebA_8X_HCFlow train_SR_DF2K_4X_HCFlow train_Rescaling_DF2K_4X_HCFlow + parser.add_argument('--opt', type=str, default='options/train/train_SR_CelebA_8X_HCFlow.yml', + help='Path to option YMAL file of MANet.') + parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--gpu_ids', type=str, default=None) + parser.add_argument('--job_id', type=str, default=0) + parser.add_argument('--job_path', type=str, default='') + args = parser.parse_args() + opt = option.parse(args.opt, args.gpu_ids, is_train=True) + device_id = torch.cuda.current_device() + + # convert to NoneDict, which returns None for missing keys + opt = option.dict_to_nonedict(opt) + print(torch.__version__) + + #### random seed + seed = 
opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + util.set_random_seed(seed) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + + opt['rank'], opt['world_size'] = get_dist_info() + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + #### loading resume state if exists + if opt['path'].get('resume_state', None): + resume_state_path, _ = util.get_resume_paths(opt) + if resume_state_path is None: + resume_state = None + else: + # distributed resuming: all load into default GPU + device_id = torch.cuda.current_device() + resume_state = torch.load(resume_state_path, + map_location=lambda storage, loc: storage.cuda(device_id)) + option.check_resume(opt, resume_state['iter']) # override model pretrain path with resume path + else: + resume_state = None + + #### mkdir and loggers + # normal training (rank -1) OR distributed training (rank (gpu id) 0-7) + if opt['rank'] <= 0: + if resume_state is None: + util.mkdir_and_rename( + opt['path']['experiments_root']) # rename experiment folder if exists + util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' + and 'pretrain_model' not in key and 'resume' not in key)) + + # config loggers. Before it, the log will not work + util.setup_logger('base', opt['path']['log'], 'train{}_'.format(args.job_id) + opt['name'], level=logging.INFO, + screen=True, tofile=True) + util.setup_logger('val', opt['path']['log'], 'val{}_'.format(args.job_id) + opt['name'], level=logging.INFO, + screen=True, tofile=True) + logger = logging.getLogger('base') + logger.info('{}@{}, GPU {}, Job_id {}, Job path {}'.format(getpass.getuser(), socket.gethostname(), + opt['gpu_ids'], args.job_id, args.job_path)) + logger.info(option.dict2str(opt)) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + version = float(torch.__version__[0:3]) + if version >= 1.1: # PyTorch 1.1 + from torch.utils.tensorboard import SummaryWriter + else: + logger.info( + 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) + from tensorboardX import SummaryWriter + tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) + else: + util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) + logger = logging.getLogger('base') + + # symlink the code/working dir + try: + os.symlink(args.job_path.replace('/cluster/home/{}'.format(getpass.getuser()),'/scratch/e_home'), + opt['path']['experiments_root']+'/{}'.format(os.path.basename(args.job_path))) + except: + pass + try: + os.symlink(args.job_path.replace('/cluster/home/{}'.format(getpass.getuser()),'/scratch/e_home')+'/options/train', + opt['path']['experiments_root']+'/{}'.format(args.job_id)) + except: + pass + + #### create train and val dataloader + dataset_ratio = 200 # enlarge the size of each epoch + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = create_dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) + total_iters = int(opt['train']['niter']) + total_epochs = int(math.ceil(total_iters / train_size)) + if opt['dist']: + train_sampler = DistIterSampler(train_set, opt['world_size'], opt['rank'], dataset_ratio) + total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) + else: + train_sampler = None + train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) + if opt['rank'] <= 0: + logger.info('Number of train images: {:,d}, iters: {:,d}'.format( + len(train_set), train_size)) + logger.info('Total epochs needed: {:d} for iters {:,d}'.format( + total_epochs, total_iters)) + elif phase == 'val': + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader(val_set, dataset_opt, opt, None) + if opt['rank'] <= 0: + logger.info('Number of val images in [{:s}]: {:d}'.format( + dataset_opt['name'], len(val_set))) + else: + raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) + assert train_loader is not None + assert val_loader is not None + + #### create model + model = create_model(opt) + + #### resume training + if resume_state: + logger.info('Resuming training from epoch: {}, iter: {}.'.format( + resume_state['epoch'], resume_state['iter'])) + + start_epoch = resume_state['epoch'] + current_step = resume_state['iter'] + model.resume_training(resume_state) # handle optimizers and schedulers + else: + current_step = 0 + start_epoch = 0 + + #### training + + # logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) + for epoch in range(start_epoch, total_epochs + 1): + if opt['dist']: + train_sampler.set_epoch(epoch) + + for _, train_data in enumerate(train_loader): + current_step += 1 + if current_step > total_iters: + break + + #### update learning rate, schedulers + model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) + + #### training + model.feed_data(train_data) + model.optimize_parameters(current_step) + + #### log + if current_step % opt['logger']['print_freq'] == 0: + logs = model.get_current_log() + message = ' '.format( + epoch, current_step, model.get_current_learning_rate()) + for k, v in logs.items(): + message += '{:s}:{:.4e} '.format(k, v) + # tensorboard logger, but sometimes cause dead + if opt['use_tb_logger'] and 'debug' not in opt['name']: + if opt['rank'] <= 0: + tb_logger.add_scalar(k, v, current_step) + if opt['rank'] <= 0: + logger.info(message) + + #### save models and training states before validation + if current_step % 
opt['logger']['save_checkpoint_freq'] == 0: + if opt['rank'] <= 0: + logger.info('Saving models and training states.') + model.save(current_step) + model.save_training_state(epoch, current_step) + + # validation + if (current_step % opt['train']['val_freq'] == 0 ) and opt['rank'] <= 0: + idx = 0 + psnr_dict = {} + psnr_y_dict = {} + loss_fn_alex = lpips.LPIPS(net='alex').to('cuda') + lpips_dict = {} + diversity_dict = {} # pixel-wise variance + avg_lr_psnr_y = 0.0 + avg_nll = 0.0 + + for _, val_data in enumerate(val_loader): + idx += 1 + + model.feed_data(val_data) + avg_nll += model.test() + visuals = model.get_current_visuals() + + # create dir for each iteration + img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] + img_dir = os.path.join(opt['path']['val_images'], str(current_step)) + util.mkdir(img_dir) + + # calculate LR psnr + gt_img_lr = util.tensor2img(visuals['LQ']) + sr_img_lr = util.tensor2img(visuals['LQ_fromH']) + gt_img_lr = gt_img_lr / 255. + sr_img_lr = sr_img_lr / 255. + + _, _, lr_psnr_y, _ = util.calculate_psnr_ssim(gt_img_lr, sr_img_lr, 0) + avg_lr_psnr_y += lr_psnr_y + + # deal with sr images + for heat in opt['val']['heats']: + sr_img_list =[] + for sample in range(opt['val']['n_sample']): + gt_img = visuals['GT'] + sr_img = visuals['SR', heat, sample] + sr_img_list.append(sr_img.unsqueeze(0)*255) + lpips_dict[(idx, heat, sample)] = float(loss_fn_alex(2 * gt_img.to('cuda') - 1, 2 * sr_img.to('cuda') - 1).cpu()) + + gt_img = util.tensor2img(gt_img) # uint8 + sr_img = util.tensor2img(sr_img) # uint8 + save_img_path = os.path.join(img_dir, 'SR_{:s}_{:.1f}_{:d}_{:d}.png'.format(img_name, heat, sample, current_step)) + util.save_img(sr_img, save_img_path) + + gt_img = gt_img / 255. + sr_img = sr_img / 255. + + crop_border = opt['crop_border'] if opt['crop_border'] else opt['scale'] + psnr_dict[(idx, heat, sample)], ssim, psnr_y_dict[(idx, heat, sample)], ssim_y = util.calculate_psnr_ssim(gt_img, sr_img, crop_border) + + # mean pixel-wise variance + diversity_dict[(idx, heat)] = float(torch.cat(sr_img_list, 0).std([0]).mean().cpu()) + + # log + logger.info('{}@{}, GPU {}, Job_id {}, Job path {}'.format(getpass.getuser(), socket.gethostname(), + opt['gpu_ids'], args.job_id, args.job_path)) + logger.info('# {}, Validation ()'.format(opt['name'], epoch, current_step)) + logger_val = logging.getLogger('val') # validation logger + logger_val.info('# {}, Validation ()'.format(opt['name'], epoch, current_step)) + + avg_lr_psnr_y = avg_lr_psnr_y / idx + avg_nll = avg_nll / idx + for heat in opt['val']['heats']: + avg_psnr = 0.0 + avg_psnr_y = 0.0 + avg_lpips = 0.0 + avg_diversity = 0.0 + + for iidx in range(1, idx+1): + for sample in range(opt['val']['n_sample']): + avg_psnr += psnr_dict[(iidx, heat, sample)] + avg_psnr_y += psnr_y_dict[(iidx, heat, sample)] + avg_lpips += lpips_dict[(iidx, heat, sample)] + avg_diversity += diversity_dict[(iidx, heat)] + + avg_psnr = avg_psnr / idx / opt['val']['n_sample'] + avg_psnr_y = avg_psnr_y / idx / opt['val']['n_sample'] + avg_lpips = avg_lpips / idx / opt['val']['n_sample'] + avg_diversity = avg_diversity / idx + + # log + logger.info('({}samples,heat:{:.1f}) PSNR/PSNR_Y/LPIPS/Diversity: {:.2f}/{:.2f}/{:.4f}/{:.4f}, LR_PSNR_Y: {:.2f}, NLL: {:.4f}'.format( + opt['val']['n_sample'], heat, avg_psnr, avg_psnr_y, avg_lpips, avg_diversity, avg_lr_psnr_y, avg_nll)) + logger_val.info('({}samples,heat:{:.1f}) PSNR/PSNR_Y/LPIPS/Diversity: {:.2f}/{:.2f}/{:.4f}/{:.4f}, LR_PSNR_Y: {:.2f}, NLL: {:.4f}'.format( + 
opt['val']['n_sample'], heat, avg_psnr, avg_psnr_y, avg_lpips, avg_diversity, avg_lr_psnr_y, avg_nll)) + + + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + tb_logger.add_scalar('psnr_{:.1f}'.format(heat), avg_psnr, current_step) + tb_logger.add_scalar('psnr_y_{:.1f}'.format(heat), avg_psnr_y, current_step) + tb_logger.add_scalar('lpips_{:.1f}'.format(heat), avg_lpips, current_step) + tb_logger.add_scalar('diversity_{:.1f}'.format(heat), avg_diversity, current_step) + tb_logger.add_scalar('lr_psnr_y_{:.1f}'.format(heat), avg_lr_psnr_y, current_step) + tb_logger.add_scalar('nll_{:.1f}'.format(heat), avg_nll, current_step) + + del loss_fn_alex, visuals + + if opt['rank'] <= 0: + logger.info('Saving the final model.') + model.save('latest') + logger.info('End of model training.') + + +if __name__ == '__main__': + main() diff --git a/codes/utils/dist_util.py b/codes/utils/dist_util.py new file mode 100644 index 0000000..43cf4cd --- /dev/null +++ b/codes/utils/dist_util.py @@ -0,0 +1,83 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
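+    Note:
+        ``SLURM_PROCID``, ``SLURM_NTASKS`` and ``SLURM_NODELIST`` are read to set
+        ``RANK``, ``WORLD_SIZE``, ``LOCAL_RANK`` and ``MASTER_ADDR``; ``MASTER_PORT``
+        comes from the ``port`` argument or the existing environment variable.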
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/codes/utils/imresize.py b/codes/utils/imresize.py new file mode 100644 index 0000000..e920d1c --- /dev/null +++ b/codes/utils/imresize.py @@ -0,0 +1,180 @@ +# https://github.com/fatheral/matlab_imresize +# +# MIT License +# +# Copyright (c) 2020 Alex +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ + +from __future__ import print_function +import numpy as np +from math import ceil, floor + + +def deriveSizeFromScale(img_shape, scale): + output_shape = [] + for k in range(2): + output_shape.append(int(ceil(scale[k] * img_shape[k]))) + return output_shape + + +def deriveScaleFromSize(img_shape_in, img_shape_out): + scale = [] + for k in range(2): + scale.append(1.0 * img_shape_out[k] / img_shape_in[k]) + return scale + + +def triangle(x): + x = np.array(x).astype(np.float64) + lessthanzero = np.logical_and((x >= -1), x < 0) + greaterthanzero = np.logical_and((x <= 1), x >= 0) + f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero) + return f + + +def cubic(x): + x = np.array(x).astype(np.float64) + absx = np.absolute(x) + absx2 = np.multiply(absx, absx) + absx3 = np.multiply(absx2, absx) + f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2, + (1 < absx) & (absx <= 2)) + return f + + +def contributions(in_length, out_length, scale, kernel, k_width): + if scale < 1: + h = lambda x: scale * kernel(scale * x) + kernel_width = 1.0 * k_width / scale + else: + h = kernel + kernel_width = k_width + x = np.arange(1, out_length + 1).astype(np.float64) + u = x / scale + 0.5 * (1 - 1 / scale) + left = np.floor(u - kernel_width / 2) + P = int(ceil(kernel_width)) + 2 + ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0 + indices = ind.astype(np.int32) + weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0 + weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1)) + aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32) + indices = aux[np.mod(indices, aux.size)] + ind2store = np.nonzero(np.any(weights, axis=0)) + weights = weights[:, ind2store] + indices = indices[:, ind2store] + return weights, indices + + +def imresizemex(inimg, weights, indices, dim): + in_shape = inimg.shape + w_shape = weights.shape + out_shape = list(in_shape) + out_shape[dim] = w_shape[0] + outimg = np.zeros(out_shape) + if dim == 0: + for i_img in range(in_shape[1]): + for i_w in range(w_shape[0]): + w = weights[i_w, :] + ind = indices[i_w, :] + im_slice = inimg[ind, i_img].astype(np.float64) + outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0) + elif dim == 1: + for i_img in range(in_shape[0]): + for i_w in range(w_shape[0]): + w = weights[i_w, :] + ind = indices[i_w, :] + im_slice = inimg[i_img, ind].astype(np.float64) + outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0) + if inimg.dtype == np.uint8: + outimg = np.clip(outimg, 0, 255) + return np.around(outimg).astype(np.uint8) + else: + return outimg + + +def imresizevec(inimg, weights, indices, dim): + wshape = weights.shape + if dim == 0: + weights = weights.reshape((wshape[0], wshape[2], 1, 1)) + outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1) + elif dim == 1: + weights = weights.reshape((1, wshape[0], wshape[2], 1)) + outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2) + if inimg.dtype == np.uint8: + outimg = np.clip(outimg, 0, 255) + return np.around(outimg).astype(np.uint8) + else: + return outimg + + +def resizeAlongDim(A, dim, weights, indices, mode="vec"): + if mode == "org": + out = imresizemex(A, weights, indices, dim) + else: + out = imresizevec(A, weights, indices, dim) + return out + + +def 
imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"): + if method is 'bicubic': + kernel = cubic + elif method is 'bilinear': + kernel = triangle + else: + print('Error: Unidentified method supplied') + + kernel_width = 4.0 + # Fill scale and output_size + if scalar_scale is not None: + scalar_scale = float(scalar_scale) + scale = [scalar_scale, scalar_scale] + output_size = deriveSizeFromScale(I.shape, scale) + elif output_shape is not None: + scale = deriveScaleFromSize(I.shape, output_shape) + output_size = list(output_shape) + else: + print('Error: scalar_scale OR output_shape should be defined!') + return + scale_np = np.array(scale) + order = np.argsort(scale_np) + weights = [] + indices = [] + for k in range(2): + w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width) + weights.append(w) + indices.append(ind) + B = np.copy(I) + flag2D = False + if B.ndim == 2: + B = np.expand_dims(B, axis=2) + flag2D = True + for k in range(2): + dim = order[k] + B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode) + if flag2D: + B = np.squeeze(B, axis=2) + return B + + +def convertDouble2Byte(I): + B = np.clip(I, 0.0, 1.0) + B = 255 * B + return np.around(B).astype(np.uint8) \ No newline at end of file diff --git a/codes/utils/misc.py b/codes/utils/misc.py new file mode 100644 index 0000000..1d4e5eb --- /dev/null +++ b/codes/utils/misc.py @@ -0,0 +1,467 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if float(torchvision.__version__[:3]) < 0.7: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
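+        Only ``count`` and ``total`` are summed across processes; the windowed
+        values remain local to each rank.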
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = 
list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
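+# It pads every image (and its mask) to the per-batch maximum size with F.pad and then
+# stacks them, instead of copying into a pre-allocated tensor as the eager path does.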
+@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty 
batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if float(torchvision.__version__[:3]) < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/codes/utils/timer.py b/codes/utils/timer.py new file mode 100644 index 0000000..de15320 --- /dev/null +++ b/codes/utils/timer.py @@ -0,0 +1,62 @@ +import time + + +class ScopeTimer: + def __init__(self, name): + self.name = name + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.end = time.time() + self.interval = self.end - self.start + print("{} {:.3E}".format(self.name, self.interval)) + + +class Timer: + def __init__(self): + self.times = [] + + def tick(self): + self.times.append(time.time()) + + def get_average_and_reset(self): + if len(self.times) < 2: + return -1 + avg = (self.times[-1] - self.times[0]) / (len(self.times) - 1) + self.times = [self.times[-1]] + return avg + + def get_last_iteration(self): + if len(self.times) < 2: + return 0 + return self.times[-1] - self.times[-2] + + +class TickTock: + def __init__(self): + self.time_pairs = [] + self.current_time = None + + def tick(self): + self.current_time = time.time() + + def tock(self): + assert self.current_time is not None, self.current_time + self.time_pairs.append([self.current_time, time.time()]) + self.current_time = None + + def get_average_and_reset(self): + if len(self.time_pairs) == 0: + return -1 + deltas = [t2 - t1 for t1, t2 in self.time_pairs] + avg = sum(deltas) / len(deltas) + self.time_pairs = [] + return avg + + def get_last_iteration(self): + if len(self.time_pairs) == 0: + return -1 + return self.time_pairs[-1][1] - self.time_pairs[-1][0] diff --git a/codes/utils/util.py b/codes/utils/util.py new file mode 100644 index 0000000..bfd0b7f --- /dev/null +++ b/codes/utils/util.py @@ -0,0 +1,1243 @@ +import os +import sys +import time +import math +from datetime import datetime +import random +import logging +from collections import OrderedDict +import glob + +import natsort +import numpy as np +import cv2 +import torch +from torchvision.utils import make_grid +from shutil import get_terminal_size +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from PIL import Image +import collections +from scipy import signal +try: + import accimage +except ImportError: + accimage = None + +import yaml + +try: + from yaml import CLoader as Loader, CDumper as Dumper +except ImportError: + from yaml import Loader, Dumper + +import scipy +import matplotlib +matplotlib.use('PS') +import matplotlib.pyplot as plt +from scipy.interpolate import interp2d + +def OrderedYaml(): + '''yaml orderedDict support''' + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def _is_pil_image(img): + if accimage is not None: + return isinstance(img, (Image.Image, accimage.Image)) + 
else: + return isinstance(img, Image.Image) + + +def _is_tensor_image(img): + return torch.is_tensor(img) and img.ndimension() == 3 + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def to_pil_image(pic, mode=None): + """Convert a tensor or an ndarray to PIL Image. + + See :class:`~torchvision.transforms.ToPIlImage` for more details. + + Args: + pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. + mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). + + .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes + + Returns: + PIL Image: Image converted to PIL Image. + """ + if not(_is_numpy_image(pic) or _is_tensor_image(pic)): + raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) + + npimg = pic + if isinstance(pic, torch.FloatTensor): + pic = pic.mul(255).byte() + if torch.is_tensor(pic): + npimg = np.transpose(pic.numpy(), (1, 2, 0)) + + if not isinstance(npimg, np.ndarray): + raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + + 'not {}'.format(type(npimg))) + + if npimg.shape[2] == 1: + expected_mode = None + npimg = npimg[:, :, 0] + if npimg.dtype == np.uint8: + expected_mode = 'L' + if npimg.dtype == np.int16: + expected_mode = 'I;16' + if npimg.dtype == np.int32: + expected_mode = 'I' + elif npimg.dtype == np.float32: + expected_mode = 'F' + if mode is not None and mode != expected_mode: + raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" + .format(mode, np.dtype, expected_mode)) + mode = expected_mode + + elif npimg.shape[2] == 4: + permitted_4_channel_modes = ['RGBA', 'CMYK'] + if mode is not None and mode not in permitted_4_channel_modes: + raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) + + if mode is None and npimg.dtype == np.uint8: + mode = 'RGBA' + else: + permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] + if mode is not None and mode not in permitted_3_channel_modes: + raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) + if mode is None and npimg.dtype == np.uint8: + mode = 'RGB' + + if mode is None: + raise TypeError('Input type {} is not supported'.format(npimg.dtype)) + + return Image.fromarray(npimg, mode=mode) + + +def to_tensor(pic): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. + + See ``ToTensor`` for more details. + + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + if not(_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError('pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) + + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic.transpose((2, 0, 1))) + # backward compatibility + return img.float().div(255) + + if accimage is not None and isinstance(pic, accimage.Image): + nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) + pic.copyto(nppic) + return torch.from_numpy(nppic) + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float().div(255) + else: + return img + + +def resize(img, size, interpolation=Image.BILINEAR): + """Resize the input PIL Image to the given size. + + Args: + img (PIL Image): Image to be resized. + size (sequence or int): Desired output size. If size is a sequence like + (h, w), the output size will be matched to this. If size is an int, + the smaller edge of the image will be matched to this number maintaing + the aspect ratio. i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + + Returns: + PIL Image: Resized image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): + raise TypeError('Got inappropriate size arg: {}'.format(size)) + + if isinstance(size, int): + w, h = img.size + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + return img.resize((ow, oh), interpolation) + else: + oh = size + ow = int(size * w / h) + return img.resize((ow, oh), interpolation) + else: + return img.resize(size[::-1], interpolation) + + +#################### +# PCA +#################### + +def PCA(data, k=2): + X = torch.from_numpy(data) + X_mean = torch.mean(X, 0) + X = X - X_mean.expand_as(X) + U, S, V = torch.svd(torch.t(X)) + return U[:, :k] # PCA matrix + +def cal_sigma(sig_x, sig_y, radians): + D = np.array([[sig_x ** 2, 0], [0, sig_y ** 2]]) + U = np.array([[np.cos(radians), -np.sin(radians)], [np.sin(radians), 1 * np.cos(radians)]]) + sigma = np.dot(U, np.dot(D, U.T)) + return sigma + +#################### +# anisotropic gaussian kernels, identical to 'mvnpdf(X,mu,sigma)' in matlab +# due to /np.sqrt((2*np.pi)**2 * sig1*sig2), `sig1=sig2=8` != `sigma=8` in matlab +# rotation matrix [[cos, -sin],[sin, cos]] +#################### + +def anisotropic_gaussian_kernel_matlab(l, sig1, sig2, theta, tensor=False): + # mean = [0, 0] + # v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + # V = np.array([[v[0], v[1]], [v[1], -v[0]]]) # [[cos, sin], [sin, -cos]] + # D = np.array([[sig1, 0], [0, sig2]]) + # cov = np.dot(np.dot(V, D), V) # VD(V^-1), V=V^-1 + + cov11 = sig1*np.cos(theta)**2 + sig2*np.sin(theta)**2 + cov22 = sig1*np.sin(theta)**2 + sig2*np.cos(theta)**2 + cov21 = (sig1-sig2)*np.cos(theta)*np.sin(theta) + cov = np.array([[cov11, cov21], [cov21, cov22]]) + + center = l / 2.0 - 0.5 + x, y = np.mgrid[-center:-center+l:1, -center:-center+l:1] + pos = np.dstack((y, x)) + k = scipy.stats.multivariate_normal.pdf(pos, mean=[0, 0], cov=cov) + + k[k < scipy.finfo(float).eps * k.max()] = 0 + sumk = k.sum() + if sumk != 0: + k = k/sumk + + return torch.FloatTensor(k) if tensor else k + +#################### +# isotropic gaussian kernels, identical to 'fspecial('gaussian',hsize,sigma)' in matlab +#################### + +def isotropic_gaussian_kernel_matlab(l, sigma, tensor=False): + center = [(l-1.0)/2.0, (l-1.0)/2.0] + [x, y] = np.meshgrid(np.arange(-center[1], center[1]+1), np.arange(-center[0], center[0]+1)) + arg = -(x*x + y*y)/(2*sigma*sigma) + k = np.exp(arg) + + k[k < scipy.finfo(float).eps * k.max()] = 0 + sumk = k.sum() + if sumk != 0: + k = k/sumk + + return torch.FloatTensor(k) if tensor else k + +#################### +# random/stable ani/isotropic gaussian kernel batch generation +#################### + +def random_anisotropic_gaussian_kernel(l=15, sig_min=0.2, sig_max=4.0, scale=3, tensor=False): + sig1 = sig_min + (sig_max-sig_min)*np.random.rand() + sig2 = sig_min + (sig1-sig_min)*np.random.rand() + theta = np.pi*np.random.rand() + + k = anisotropic_gaussian_kernel_matlab(l=l, sig1=sig1, sig2=sig2, theta=theta, tensor=tensor) + return k, np.array([sig1, sig2, theta]) + +def stable_anisotropic_gaussian_kernel(l=15, sig1=2.6, sig2=2.6, theta=0, scale=3, tensor=False): + k = anisotropic_gaussian_kernel_matlab(l=l, sig1=sig1, sig2=sig2, theta=theta, tensor=tensor) + return k, np.array([sig1, sig2, theta]) + +def random_isotropic_gaussian_kernel(l=21, sig_min=0.2, sig_max=4.0, scale=3, tensor=False): + x = 
np.random.random() * (sig_max - sig_min) + sig_min + k = isotropic_gaussian_kernel_matlab(l, x, tensor=tensor) + return k, np.array([x, x, 0]) + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf-1)*0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w-1) + y1 = np.clip(y1, 0, h-1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def stable_isotropic_gaussian_kernel(l=21, sig=2.6, scale=3, tensor=False): + k = isotropic_gaussian_kernel_matlab(l, sig, tensor=tensor) + # shift version 1: interpolation + # k = shift_pixel(k, scale) + # k = k/k.sum() + return k, np.array([sig, sig, 0]) + +def random_gaussian_kernel(l=21, sig_min=0.2, sig_max=4.0, rate_iso=1.0, scale=3, tensor=False): + if np.random.random() < rate_iso: + return random_isotropic_gaussian_kernel(l=l, sig_min=sig_min, sig_max=sig_max, scale=scale, tensor=tensor) + else: + return random_anisotropic_gaussian_kernel(l=l, sig_min=sig_min, sig_max=sig_max, scale=scale, tensor=tensor) + +def stable_gaussian_kernel(l=21, sig=2.6, sig1=2.6, sig2=2.6, theta=0, rate_iso=1.0, scale=3, tensor=False): + if np.random.random() < rate_iso: + return stable_isotropic_gaussian_kernel(l=l, sig=sig, scale=scale, tensor=tensor) + else: + return stable_anisotropic_gaussian_kernel(l=l, sig1=sig1, sig2=sig2, theta=theta, scale=scale, tensor=tensor) + +# only these two func can be used outside this script +def random_batch_kernel(batch, l=21, sig_min=0.2, sig_max=4.0, rate_iso=1.0, scale=3, tensor=True): + batch_kernel = np.zeros((batch, l, l)) + batch_sigma = np.zeros((batch, 3)) + shifted_l = l - scale + 1 + for i in range(batch): + batch_kernel[i, :shifted_l, :shifted_l], batch_sigma[i, :] = \ + random_gaussian_kernel(l=shifted_l, sig_min=sig_min, sig_max=sig_max, rate_iso=rate_iso, scale=scale, tensor=False) + if tensor: + return torch.FloatTensor(batch_kernel), torch.FloatTensor(batch_sigma) + else: + return batch_kernel, batch_sigma + +def stable_batch_kernel(batch, l=21, sig=2.6, sig1=2.6, sig2=2.6, theta=0, rate_iso=1.0, scale=3, tensor=True): + batch_kernel = np.zeros((batch, l, l)) + batch_sigma = np.zeros((batch, 3)) + shifted_l = l - scale + 1 + for i in range(batch): + batch_kernel[i, :shifted_l, :shifted_l], batch_sigma[i, :] = \ + stable_gaussian_kernel(l=shifted_l, sig=sig, sig1=sig1, sig2=sig2, theta=theta, rate_iso=rate_iso, scale=scale, tensor=False) + if tensor: + return torch.FloatTensor(batch_kernel), torch.FloatTensor(batch_sigma) + else: + return batch_kernel, batch_sigma + +# for SVKE, MutualAffineConv +def stable_batch_kernel_SV_mode(batch, img_H=250, img_W=250, divide_H=1, divide_W=1, sv_mode=0, l=21, sig=2.6, sig1=2.6, sig2=2.6, theta=0, rate_iso=1.0, scale=3, tensor=True): + batch_kernel = np.zeros((batch, img_H*img_W, l, l)) + batch_sigma = np.zeros((batch, img_H*img_W, 3)) + shifted_l = l - scale + 1 + a = (2.5-0.175)*scale + b = 0.175*scale + for ibatch in range(batch): + block_H = math.ceil(img_H/divide_H) + block_W = math.ceil(img_W/divide_W) + for h in range(block_H): + for w in range(block_W): + if sv_mode == 1: + sig1 = a + b + sig2 = a * h/block_H + b + theta = 0 + elif sv_mode == 2: + sig1 = a * 
w/block_W + b + sig2 = a * h/block_H + b + theta = 0 + elif sv_mode == 3: + sig1 = a + b + sig2 = b + theta = np.pi * (h/block_H) + elif sv_mode == 4: + sig1 = a * w/block_W + b + sig2 = a * h/block_H + b + theta = np.pi * (h/block_H) + elif sv_mode == 5: + sig1 = np.random.uniform(b, a+b) + sig2 = np.random.uniform(b, a+b) + theta = np.random.uniform(0, np.pi) + elif sv_mode == 6: + sig1 = a + b + sig2 = b + if (h+w)%2 == 0: + theta = np.pi/4 + else: + theta = np.pi*3/4 + + kernel_hw, sigma_hw = stable_gaussian_kernel(l=shifted_l, sig=sig, sig1=sig1, sig2=sig2, theta=theta, + rate_iso=rate_iso, scale=scale, tensor=False) + + for m in range(divide_H): + for k in range(divide_W): + pos_h, pos_w = h*divide_H+m, w*divide_W+k + if pos_h < img_H and pos_w < img_W: + batch_kernel[ibatch, pos_h*img_W+pos_w, :shifted_l, :shifted_l], \ + batch_sigma[ibatch, pos_h*img_W+pos_w, :] = kernel_hw, sigma_hw + + if tensor: + return torch.FloatTensor(batch_kernel), torch.FloatTensor(batch_sigma) + else: + return batch_kernel, batch_sigma + + + +#################### +# bicubic downsampling +#################### + +def b_GPUVar_Bicubic(variable, scale): + tensor = variable.cpu().data + B, C, H, W = tensor.size() + H_new = int(H / scale) + W_new = int(W / scale) + tensor_view = tensor.view((B*C, 1, H, W)) + re_tensor = torch.zeros((B*C, 1, H_new, W_new)) + for i in range(B*C): + img = to_pil_image(tensor_view[i]) + re_tensor[i] = to_tensor(resize(img, (H_new, W_new), interpolation=Image.BICUBIC)) + re_tensor_view = re_tensor.view((B, C, H_new, W_new)) + return re_tensor_view + +def b_CPUVar_Bicubic(variable, scale): + tensor = variable.data + B, C, H, W = tensor.size() + H_new = int(H / scale) + W_new = int(W / scale) + tensor_v = tensor.view((B*C, 1, H, W)) + re_tensor = torch.zeros((B*C, 1, H_new, W_new)) + for i in range(B*C): + img = to_pil_image(tensor_v[i]) + re_tensor[i] = to_tensor(resize(img, (H_new, W_new), interpolation=Image.BICUBIC)) + re_tensor_v = re_tensor.view((B, C, H_new, W_new)) + return re_tensor_v + +class BatchBicubic(nn.Module): + def __init__(self, scale=4): + super(BatchBicubic, self).__init__() + self.scale = scale + + def forward(self, input): + tensor = input.cpu().data + B, C, H, W = tensor.size() + H_new = int(H / self.scale) + W_new = int(W / self.scale) + tensor_view = tensor.view((B*C, 1, H, W)) + re_tensor = torch.zeros((B*C, 1, H_new, W_new)) + for i in range(B*C): + img = to_pil_image(tensor_view[i]) + re_tensor[i] = to_tensor(resize(img, (H_new, W_new), interpolation=Image.BICUBIC)) + re_tensor_view = re_tensor.view((B, C, H_new, W_new)) + return re_tensor_view + +class BatchSubsample(nn.Module): + def __init__(self, scale=4): + super(BatchSubsample, self).__init__() + self.scale = scale + + def forward(self, input): + return input[:, :, 0::self.scale, 0::self.scale] + +#################### +# image noises +#################### + +def random_batch_noise(batch, high, rate_cln=1.0): + noise_level = np.random.uniform(size=(batch, 1)) * high + noise_mask = np.random.uniform(size=(batch, 1)) + noise_mask[noise_mask < rate_cln] = 0 + noise_mask[noise_mask >= rate_cln] = 1 + return noise_level * noise_mask + + +def b_GaussianNoising(tensor, sigma, mean=0.0, noise_size=None, min=0.0, max=1.0): + if noise_size is None: + size = tensor.size() + else: + size = noise_size + noise = torch.mul(sigma.new_tensor(np.random.normal(loc=mean, scale=1.0, size=size)), sigma.view(sigma.size() + (1, 1))) + return torch.clamp(noise + tensor, min=min, max=max) + + +#################### +# batch 
degradation +#################### + +class BatchSRKernel(object): + def __init__(self, l=21, sig=2.6, sig1=2.6, sig2=2.6, theta=0, sig_min=0.2, sig_max=4.0, rate_iso=1.0, scale=3): + self.l = l + self.sig = sig + self.sig1 = sig1 + self.sig2 = sig2 + self.theta = theta + self.sig_min = sig_min + self.sig_max = sig_max + self.rate_iso = rate_iso + self.scale = scale + + def __call__(self, random, batch, tensor=False): + if random == True: #random kernel + return random_batch_kernel(batch, l=self.l, sig_min=self.sig_min, sig_max=self.sig_max, rate_iso=self.rate_iso, + scale=self.scale, tensor=tensor) + else: #stable kernel + return stable_batch_kernel(batch, l=self.l, sig=self.sig, sig1=self.sig1, sig2=self.sig2, theta=self.theta, + rate_iso=self.rate_iso, scale=self.scale, tensor=tensor) + +class BatchSRKernel_SV(object): + def __init__(self, l=21, sig=2.6, sig1=2.6, sig2=2.6, theta=0, sig_min=0.2, sig_max=4.0, rate_iso=1.0, scale=3, divide_H=1, divide_W=1, sv_mode=0): + self.l = l + self.sig = sig + self.sig1 = sig1 + self.sig2 = sig2 + self.theta = theta + self.sig_min = sig_min + self.sig_max = sig_max + self.rate_iso = rate_iso + self.scale = scale + self.divide_H = divide_H + self.divide_W = divide_W + self.sv_mode = sv_mode + assert rate_iso == 0, 'only support aniso kernel at present' + + # currently only support batch=1, stable mode + def __call__(self, random, batch, img_H, img_W, tensor=False): + return stable_batch_kernel_SV_mode(batch, img_H=img_H, img_W=img_W, divide_H=self.divide_H, divide_W=self.divide_W, sv_mode=self.sv_mode, l=self.l, sig=self.sig, sig1=self.sig1, sig2=self.sig2, theta=self.theta, + rate_iso=self.rate_iso, scale=self.scale, tensor=tensor) + +class PCAEncoder(object): + def __init__(self, weight, device=torch.device('cuda')): + self.weight = weight.to(device) #[l^2, k] + self.size = self.weight.size() + + def __call__(self, batch_kernel): + B, H, W = batch_kernel.size() #[B, l, l] + return torch.bmm(batch_kernel.view((B, 1, H * W)), self.weight.expand((B, ) + self.size)).view((B, -1)) + +class PCADecoder(object): + def __init__(self, weight, device=torch.device('cuda')): + self.weight = weight.permute(1,0).to(device) #[k, l^2] + self.size = self.weight.size() + + def __call__(self, batch_kernel_map): + B, _ = batch_kernel_map.size() #[B, l, l] + return torch.bmm(batch_kernel_map.unsqueeze(1), self.weight.expand((B, ) + self.size)).view((B, int(self.size[1]**0.5), int(self.size[1]**0.5))) + +class CircularPad2d(nn.Module): + def __init__(self, pad): + super(CircularPad2d, self).__init__() + self.pad = pad + + def forward(self, input): + return F.pad(input, pad=self.pad, mode='circular') + +class BatchBlur(nn.Module): + def __init__(self, l=15, padmode='reflection'): + super(BatchBlur, self).__init__() + self.l = l + if padmode == 'reflection': + if l % 2 == 1: + self.pad = nn.ReflectionPad2d(l // 2) + else: + self.pad = nn.ReflectionPad2d((l // 2, l // 2 - 1, l // 2, l // 2 - 1)) + elif padmode == 'zero': + if l % 2 == 1: + self.pad = nn.ZeroPad2d(l // 2) + else: + self.pad = nn.ZeroPad2d((l // 2, l // 2 - 1, l // 2, l // 2 - 1)) + elif padmode == 'replication': + if l % 2 == 1: + self.pad = nn.ReplicationPad2d(l // 2) + else: + self.pad = nn.ReplicationPad2d((l // 2, l // 2 - 1, l // 2, l // 2 - 1)) + elif padmode == 'circular': + if l % 2 == 1: + self.pad = CircularPad2d((l // 2, l // 2, l // 2, l // 2)) + else: + self.pad = CircularPad2d((l // 2, l // 2 - 1, l // 2, l // 2 - 1)) + else: + raise NotImplementedError + + def forward(self, input, kernel): + B, 
C, H, W = input.size()
+        pad = self.pad(input)
+        H_p, W_p = pad.size()[-2:]
+
+        if len(kernel.size()) == 2:
+            input_CBHW = pad.view((C * B, 1, H_p, W_p))
+            kernel_var = kernel.contiguous().view((1, 1, self.l, self.l))
+            return F.conv2d(input_CBHW, kernel_var, padding=0).view((B, C, H, W))
+        else:
+            input_CBHW = pad.view((1, C * B, H_p, W_p))
+            kernel_var = kernel.contiguous().view((B, 1, self.l, self.l)).repeat(1, C, 1, 1).view((B * C, 1, self.l, self.l))
+            return F.conv2d(input_CBHW, kernel_var, groups=B*C).view((B, C, H, W))
+
+
+# spatially variant blur
+class BatchBlur_SV(nn.Module):
+    def __init__(self, l=15, padmode='reflection'):
+        super(BatchBlur_SV, self).__init__()
+        self.l = l
+        if padmode == 'reflection':
+            if l % 2 == 1:
+                self.pad = nn.ReflectionPad2d(l // 2)
+            else:
+                self.pad = nn.ReflectionPad2d((l // 2, l // 2 - 1, l // 2, l // 2 - 1))
+        elif padmode == 'zero':
+            if l % 2 == 1:
+                self.pad = nn.ZeroPad2d(l // 2)
+            else:
+                self.pad = nn.ZeroPad2d((l // 2, l // 2 - 1, l // 2, l // 2 - 1))
+        elif padmode == 'replication':
+            if l % 2 == 1:
+                self.pad = nn.ReplicationPad2d(l // 2)
+            else:
+                self.pad = nn.ReplicationPad2d((l // 2, l // 2 - 1, l // 2, l // 2 - 1))
+        elif padmode == 'circular':
+            if l % 2 == 1:
+                self.pad = CircularPad2d((l // 2, l // 2, l // 2, l // 2))
+            else:
+                self.pad = CircularPad2d((l // 2, l // 2 - 1, l // 2, l // 2 - 1))
+
+    def forward(self, input, kernel):
+        # kernel of size [N,Himage*Wimage,H,W]
+        B, C, H, W = input.size()
+        pad = self.pad(input)
+        H_p, W_p = pad.size()[-2:]
+
+        if len(kernel.size()) == 2:
+            input_CBHW = pad.view((C * B, 1, H_p, W_p))
+            kernel_var = kernel.contiguous().view((1, 1, self.l, self.l))
+            return F.conv2d(input_CBHW, kernel_var, padding=0).view((B, C, H, W))
+        else:
+            pad = pad.view(C * B, 1, H_p, W_p)
+            pad = F.unfold(pad, self.l).transpose(1, 2)  # [CB, HW, k^2]
+            # NOTE: the expand(3, ...) below hard-codes a 3-channel (RGB) input
+            kernel = kernel.flatten(2).unsqueeze(0).expand(3,-1,-1,-1)
+            out_unf = (pad*kernel.contiguous().view(-1,kernel.size(2),kernel.size(3))).sum(2).unsqueeze(1)
+            out = F.fold(out_unf, (H, W), 1).view(B, C, H, W)
+
+            return out
+
+class SRMDPreprocessing(object):
+    def __init__(self, scale, random, l=21, add_noise=False, device=torch.device('cuda'), sig=2.6, sig1=2.6, sig2=2.6, theta=0,
+                 sig_min=0.2, sig_max=4.0, rate_iso=1.0, rate_cln=0.2, noise_high=0.05882, is_training=False, sv_mode=0):
+
+        self.device = device
+        self.l = l
+        self.noise = add_noise
+        self.noise_high = noise_high
+        self.rate_cln = rate_cln
+        self.scale = scale
+        self.random = random
+        self.rate_iso = rate_iso
+        self.is_training = is_training
+        self.sv_mode = sv_mode
+
+        if self.sv_mode == 0:  # spatially invariant: one kernel per image
+            self.blur = BatchBlur(l=l, padmode='replication')
+            self.kernel_gen = BatchSRKernel(l=l, sig=sig, sig1=sig1, sig2=sig2, theta=theta,
+                                            sig_min=sig_min, sig_max=sig_max, rate_iso=rate_iso, scale=scale)
+        else:  # spatially variant: per-pixel kernels (sv_mode 1-6)
+            self.blur = BatchBlur_SV(l=l, padmode='replication')
+            self.kernel_gen = BatchSRKernel_SV(l=l, sig=sig, sig1=sig1, sig2=sig2, theta=theta,
+                                               sig_min=sig_min, sig_max=sig_max, rate_iso=rate_iso, scale=scale,
+                                               divide_H=40, divide_W=40, sv_mode=sv_mode)
+        self.sample = BatchSubsample(scale=scale)
+
+
+    def __call__(self, hr_tensor, kernel=False):
+        B, C, H, W = hr_tensor.size()
+
+        # generate kernel
+        if self.sv_mode == 0:
+            b_kernels, b_sigmas = self.kernel_gen(self.random, B, tensor=True)
+        else:
+            b_kernels, b_sigmas = self.kernel_gen(self.random, B, H, W, tensor=True)
+        b_kernels, b_sigmas = b_kernels.to(self.device), b_sigmas.to(self.device)
+
+        # blur and
downsample + lr = self.sample(self.blur(hr_tensor, b_kernels)) + lr_n = lr + + # Gaussian noise + if self.noise: + if self.is_training: + Noise_level = torch.FloatTensor(random_batch_noise(B, self.noise_high, self.rate_cln)).to(self.device) + else: + Noise_level = (torch.ones(B, 1)*self.noise_high).to(self.device) + lr_n = b_GaussianNoising(lr_n, Noise_level) + if len(b_sigmas.size()) == 2: # only concat for spatially invariant kernel + b_sigmas = torch.cat([b_sigmas, Noise_level * 10], dim=1) + + # image quantization + lr = (lr * 255.).round()/255. + lr_n = (lr_n * 255.).round()/255. + + return (lr, lr_n, b_sigmas, b_kernels) if kernel else (lr, lr_n, b_sigmas) + + +#################### +# miscellaneous +#################### + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + logger = logging.getLogger('base') + logger.info('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): + '''set up logger''' + lg = logging.getLogger(logger_name) + formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', + datefmt='%y-%m-%d %H:%M:%S') + lg.setLevel(level) + if tofile: + log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) + fh = logging.FileHandler(log_file, mode='w') + fh.setFormatter(formatter) + lg.addHandler(fh) + if screen: + sh = logging.StreamHandler() + sh.setFormatter(formatter) + lg.addHandler(sh) + + +#################### +# image convert +#################### + + +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default), BGR channel order + ''' + if hasattr(tensor, 'detach'): + tensor = tensor.detach() + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +def save_img(img, img_path, mode='RGB'): + cv2.imwrite(img_path, img) + +def img2tensor(img): + ''' + # BGR to RGB, HWC to CHW, numpy to tensor + Input: img(H, W, C), [0,255], np.uint8 (default) + Output: 3D(C,H,W), RGB order, float tensor + ''' + img = img.astype(np.float32) / 255. 
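+    # reorder channels: OpenCV loads images as BGR, so swap to RGB before the HWC->CHW transpose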
+ img = img[:, :, [2, 1, 0]] + img = torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float() + return img + + +def DUF_downsample(x, scale=4): + """Downsamping with Gaussian kernel used in the DUF official code + + Args: + x (Tensor, [B, T, C, H, W]): frames to be downsampled. + scale (int): downsampling factor: 2 | 3 | 4. + """ + + assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale) + + def gkern(kernlen=13, nsig=1.6): + import scipy.ndimage.filters as fi + inp = np.zeros((kernlen, kernlen)) + # set element at the middle to one, a dirac delta + inp[kernlen // 2, kernlen // 2] = 1 + # gaussian-smooth the dirac, resulting in a gaussian filter mask + return fi.gaussian_filter(inp, nsig) + + B, T, C, H, W = x.size() + x = x.view(-1, 1, H, W) + pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter + r_h, r_w = 0, 0 + if scale == 3: + r_h = 3 - (H % 3) + r_w = 3 - (W % 3) + x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect') + + gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0) + x = F.conv2d(x, gaussian_filter, stride=scale) + x = x[:, :, 2:-2, 2:-2] + x = x.view(B, T, C, x.size(2), x.size(3)) + return x + +#################### +# metric +#################### + +def calculate_mnc(img1, img2): + # img1 and img2 have range [0, 255] + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + img2 = img2/np.sqrt(np.sum(img2**2)) + import scipy.signal as signal + temp = signal.convolve2d(img2, img1, 'full') + temp2 = np.sqrt(np.sum(img1**2)) + return np.max(temp)/temp2 + +def calculate_kernel_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(1.0 / math.sqrt(mse)) + +def calculate_psnr(img1, img2): + # img1 and img2 have range [0, 255] + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + +def calculate_mse(img1, img2): + # img1 and img2 have range [0, 255] + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + return mse + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def calculate_ssim(img1, img2): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + if img1.ndim == 
2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def calculate_psnr_ssim(img1, img2, crop_border=0): + if crop_border == 0: + cropped_img1 = img1 + cropped_img2 = img2 + else: + cropped_img1 = img1[crop_border:-crop_border, crop_border:-crop_border] + cropped_img2 = img2[crop_border:-crop_border, crop_border:-crop_border] + psnr = calculate_psnr(cropped_img1 * 255, cropped_img2 * 255) + ssim = calculate_ssim(cropped_img1 * 255, cropped_img2 * 255) + + if img2.shape[2] == 3: # RGB image + img1_y = bgr2ycbcr(img1, only_y=True) + img2_y = bgr2ycbcr(img2, only_y=True) + if crop_border == 0: + cropped_img1_y = img1_y + cropped_img2_y = img2_y + else: + cropped_img1_y = img1_y[crop_border:-crop_border, crop_border:-crop_border] + cropped_img2_y = img2_y[crop_border:-crop_border, crop_border:-crop_border] + psnr_y = calculate_psnr(cropped_img1_y * 255, cropped_img2_y * 255) + ssim_y = calculate_ssim(cropped_img1_y * 255, cropped_img2_y * 255) + else: + psnr_y, ssim_y = 0, 0 + + return psnr, ssim, psnr_y, ssim_y + + +class ProgressBar(object): + '''A progress bar which can print the progress + modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py + ''' + + def __init__(self, task_num=0, bar_width=50, start=True): + self.task_num = task_num + max_bar_width = self._get_max_bar_width() + self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) + self.completed = 0 + if start: + self.start() + + def _get_max_bar_width(self): + terminal_width, _ = get_terminal_size() + max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) + if max_bar_width < 10: + print('terminal width is too small ({}), please consider widen the terminal for better ' + 'progressbar visualization'.format(terminal_width)) + max_bar_width = 10 + return max_bar_width + + def start(self): + if self.task_num > 0: + sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( + ' ' * self.bar_width, self.task_num, 'Start...')) + else: + sys.stdout.write('completed: 0, elapsed: 0s') + sys.stdout.flush() + self.start_time = time.time() + + def update(self, msg='In progress...'): + self.completed += 1 + elapsed = time.time() - self.start_time + fps = self.completed / elapsed + if self.task_num > 0: + percentage = self.completed / float(self.task_num) + eta = int(elapsed * (1 - percentage) / percentage + 0.5) + mark_width = int(self.bar_width * percentage) + bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) + sys.stdout.write('\033[2F') # cursor up 2 lines + sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) + sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( + bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) + else: + sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( + self.completed, int(elapsed + 0.5), fps)) + sys.stdout.flush() + +#################### +# for debug +#################### +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, 
zdim='z',offset=-2,cmap=cmap) + plt.show() + plt.savefig('/home/jinliang/Downloads/tmp.png') + +def imagesc(Z): + f, ax = plt.subplots(1, 1, squeeze=False) + im = ax[0,0].imshow(Z, vmin=0, vmax=Z.max()) + plt.colorbar(im, ax=ax[0,0]) + plt.show() + plt.savefig('/home/jinliang/Downloads/tmp.png') + + +# copyed from data.util +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr, following matlab version instead of opencv + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def compute_RF_numerical(net, img_np, re_init_para=False): + ''' + https://github.com/rogertrullo/Receptive-Field-in-Pytorch/blob/master/Receptive_Field.ipynb + @param net: Pytorch network + @param img_np: numpy array to use as input to the networks, it must be full of ones and with the correct + shape. + ''' + def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.fill_(1) + m.bias.data.fill_(0) + if re_init_para: + net.apply(weights_init) + + img_ = Variable(torch.from_numpy(img_np).float().cuda(),requires_grad=True) + out_cnn=net(img_) # here we have two inputs and two outputs + out_shape=out_cnn.size() + ndims=len(out_cnn.size()) + grad=torch.zeros(out_cnn.size()).cuda() + l_tmp=[] + for i in range(ndims): + if i==0 or i ==1:#batch or channel + l_tmp.append(0) + else: + l_tmp.append(int(out_shape[i]/2)) + + grad[tuple(l_tmp)]=1 + out_cnn.backward(gradient=grad) + grad_np=img_.grad[0,0].data.cpu().numpy() + idx_nonzeros=np.where(grad_np!=0) + RF=[np.max(idx)-np.min(idx)+1 for idx in idx_nonzeros] + + return RF + +def plot_kernel(out_k_np, savepath, gt_k_np=None): + plt.clf() + if gt_k_np is None: + ax = plt.subplot(111) + im = ax.imshow(out_k_np, vmin=out_k_np.min(), vmax=out_k_np.max()) + plt.colorbar(im, ax=ax) + else: + + ax = plt.subplot(121) + im = ax.imshow(gt_k_np, vmin=gt_k_np.min(), vmax=gt_k_np.max()) + plt.colorbar(im, ax=ax) + ax.set_title('GT Kernel') + + ax = plt.subplot(122) + im = ax.imshow(out_k_np, vmin=gt_k_np.min(), vmax=gt_k_np.max()) + plt.colorbar(im, ax=ax) + ax.set_title('Kernel PSNR: {:.2f}'.format(calculate_kernel_psnr(out_k_np, gt_k_np))) + + plt.show() + plt.savefig(savepath) + +def get_resume_paths(opt): + resume_state_path = None + resume_model_path = None + if opt.get('path', {}).get('resume_state', None) == "auto": + wildcard = os.path.join(opt['path']['training_state'], "*") + paths = natsort.natsorted(glob.glob(wildcard)) + if len(paths) > 0: + resume_state_path = paths[-1] + resume_model_path = resume_state_path.replace('training_state', 'models').replace('.state', '_G.pth') + else: + resume_state_path = opt.get('path', {}).get('resume_state') + return resume_state_path, resume_model_path + + +def opt_get(opt, keys, default=None): + if opt is None: + return default + ret = opt + for k in keys: + ret = ret.get(k, None) + if ret is None: + return default + return ret + + +def get_printer(msg): + """This function returns a printer function, that prints information about a tensor's + gradient. Used by register_hook in the backward pass. 
+ """ + def printer(tensor): + if tensor.nelement() == 1: + print(f"{msg} {tensor}") + else: + print(f"{msg} shape: {tensor.shape}" + f" max_grad: {tensor.max()} min_grad: {tensor.min()}" + f" mean_grad: {tensor.mean()}") + return printer + + +def register_hook(tensor, msg=''): + """Utility function to call retain_grad and Pytorch's register_hook + in a single line, to get the gradient of a variable in debugging + """ + print(f"{msg} shape: {tensor.shape}" + f" max_value: {tensor.max()} min_value: {tensor.min()}" + f" mean_value: {tensor.mean()}") + tensor.retain_grad() + tensor.register_hook(get_printer(msg)) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + diff --git a/datasets/example_face_4X/HR/example0.png b/datasets/example_face_4X/HR/example0.png new file mode 100644 index 0000000..b41c246 Binary files /dev/null and b/datasets/example_face_4X/HR/example0.png differ diff --git a/datasets/example_face_4X/LR/example0.png b/datasets/example_face_4X/LR/example0.png new file mode 100644 index 0000000..72f9a31 Binary files /dev/null and b/datasets/example_face_4X/LR/example0.png differ diff --git a/datasets/example_general_4X/HR/butterfly.png b/datasets/example_general_4X/HR/butterfly.png new file mode 100644 index 0000000..f6acb9a Binary files /dev/null and b/datasets/example_general_4X/HR/butterfly.png differ diff --git a/datasets/example_general_4X/LR/butterfly.png b/datasets/example_general_4X/LR/butterfly.png new file mode 100644 index 0000000..b55686b Binary files /dev/null and b/datasets/example_general_4X/LR/butterfly.png differ diff --git a/illustrations/architecture.png b/illustrations/architecture.png new file mode 100644 index 0000000..39a44c0 Binary files /dev/null and b/illustrations/architecture.png differ diff --git a/illustrations/computation_graph.png b/illustrations/computation_graph.png new file mode 100644 index 0000000..d974166 Binary files /dev/null and b/illustrations/computation_graph.png differ diff --git a/illustrations/face_result.png b/illustrations/face_result.png new file mode 100644 index 0000000..a911faa Binary files /dev/null and b/illustrations/face_result.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a142a0d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,82 @@ +appnope==0.1.0 +argon2-cffi==20.1.0 +async-generator==1.10 +attrs==20.2.0 +backcall==0.2.0 +bleach==3.2.1 +certifi==2020.6.20 +cffi==1.14.3 +cycler==0.10.0 +dataclasses==0.6 +decorator==4.4.2 +defusedxml==0.6.0 +entrypoints==0.3 +environment-kernels==1.1.1 +future==0.18.2 +imageio==2.9.0 +importlib-metadata==2.0.0 +ipykernel==5.3.4 +ipython==7.19.0 +ipython-genutils==0.2.0 +ipywidgets==7.5.1 +jedi==0.17.2 +Jinja2==2.11.2 +jsonschema==3.2.0 +jupyter==1.0.0 +jupyter-client==6.1.7 +jupyter-console==6.2.0 +jupyter-core==4.6.3 +jupyterlab-pygments==0.1.2 +kiwisolver==1.3.1 +lpips==0.1.3 +MarkupSafe==1.1.1 +matplotlib==3.3.2 +mistune==0.8.4 +natsort==7.0.1 +nbclient==0.5.1 +nbconvert==6.0.7 +nbformat==5.0.8 +nest-asyncio==1.4.2 +networkx==2.5 +notebook==6.1.4 +numpy==1.19.4 +opencv-python==4.4.0.46 +packaging==20.4 +pandas==1.1.4 +pandocfilters==1.4.3 +parso==0.7.1 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==8.0.1 +prometheus-client==0.8.0 +prompt-toolkit==3.0.8 +ptyprocess==0.6.0 +pycparser==2.20 +Pygments==2.7.2 +pyparsing==2.4.7 +pyrsistent==0.17.3 +python-dateutil==2.8.1 +pytz==2020.4 +PyWavelets==1.1.1 +PyYAML==5.3.1 +pyzmq==19.0.2 +qtconsole==4.7.7 +QtPy==1.9.0 +scikit-image==0.17.2 +scipy==1.5.3 +Send2Trash==1.5.0 +six==1.15.0 +terminado==0.9.1 +tensorboard==2.4.0 +testpath==0.4.4 +tifffile==2020.10.1 +torch==1.7.1 +torchvision==0.8.1 +tornado==6.1 +tqdm==4.51.0 +traitlets==5.0.5 +typing-extensions==3.7.4.3 +wcwidth==0.2.5 +webencodings==0.5.1 +widgetsnbextension==3.5.1 +zipp==3.4.0