diff --git a/README.md b/README.md
index 80b425a7..3827e357 100755
--- a/README.md
+++ b/README.md
@@ -3,12 +3,12 @@



 # pix2pixHD
-### [[Project]](https://tcwang0509.github.io/pix2pixHD/) [[Youtube]](https://youtu.be/3AIpPlzM_qs) [[Paper]](https://arxiv.org/pdf/1711.11585.pdf)
+### [[Project]](https://tcwang0509.github.io/pix2pixHD/) [[Youtube]](https://youtu.be/3AIpPlzM_qs) [[Paper]](https://arxiv.org/pdf/1711.11585.pdf) [[Slides]](https://drive.google.com/open?id=1pObGZvEgPILtRsqEP07sv8CtNuQ55xur)
 Pytorch implementation of our method for high-resolution (e.g. 2048x1024) photorealistic image-to-image translation. It can be used for turning semantic label maps into photo-realistic images or synthesizing portraits from face label maps.

 [High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/)
 [Ting-Chun Wang](https://tcwang0509.github.io/)<sup>1</sup>, [Ming-Yu Liu](http://mingyuliu.net/)<sup>1</sup>, [Jun-Yan Zhu](http://people.eecs.berkeley.edu/~junyanz/)<sup>2</sup>, Andrew Tao<sup>1</sup>, [Jan Kautz](http://jankautz.com/)<sup>1</sup>, [Bryan Catanzaro](http://catanzaro.name/)<sup>1</sup>
 <sup>1</sup>NVIDIA Corporation, <sup>2</sup>UC Berkeley
-In arxiv, 2017.
+In CVPR, 2018.
 
 ## Image-to-image translation at 2k/1k resolution
 - Our label-to-streetview results
@@ -123,11 +123,11 @@ If only GPUs with 12G memory are available, please use the 12G script (`bash ./s
 If you find this useful for your research, please use the following.
 
 ```
-@article{wang2017highres,
+@inproceedings{wang2018pix2pixHD,
   title={High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs},
-  author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
-  journal={arXiv preprint arXiv:1711.11585},
-  year={2017}
+  author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
+  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+  year={2018}
 }
 ```
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index 81d54c49..6e01a012 100755
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -16,7 +16,7 @@ def initialize(self, opt):
         self.A_paths = sorted(make_dataset(self.dir_A))
 
         ### input B (real images)
-        if opt.isTrain:
+        if opt.isTrain or opt.use_encoded_image:
             dir_B = '_B' if self.opt.label_nc == 0 else '_img'
             self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
             self.B_paths = sorted(make_dataset(self.dir_B))
@@ -48,7 +48,7 @@ def __getitem__(self, index):
         B_tensor = inst_tensor = feat_tensor = 0
         ### input B (real images)
-        if self.opt.isTrain:
+        if self.opt.isTrain or self.opt.use_encoded_image:
             B_path = self.B_paths[index]
             B = Image.open(B_path).convert('RGB')
             transform_B = get_transform(self.opt, params)
diff --git a/encode_features.py b/encode_features.py
index 84372644..3adfedd8 100755
--- a/encode_features.py
+++ b/encode_features.py
@@ -12,6 +12,7 @@
 opt.serial_batches = True
 opt.no_flip = True
 opt.instance_feat = True
+opt.continue_train = True
 
 name = 'features'
 save_path = os.path.join(opt.checkpoints_dir, opt.name)
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
index 57c18b6c..851460e1 100755
--- a/models/pix2pixHD_model.py
+++ b/models/pix2pixHD_model.py
@@ -126,7 +126,7 @@ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None,
         if not self.opt.no_instance:
             inst_map = inst_map.data.cuda()
             edge_map = self.get_edges(inst_map)
-            input_label = torch.cat((input_label, edge_map), dim=1) 
+            input_label = torch.cat((input_label, edge_map), dim=1)
         input_label = Variable(input_label, volatile=infer)
 
         # real images for training
@@ -138,6 +138,8 @@ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None,
         # get precomputed feature maps
         if self.opt.load_features:
             feat_map = Variable(feat_map.data.cuda())
+        if self.opt.label_feat:
+            inst_map = label_map.cuda()
 
         return input_label, inst_map, real_image, feat_map
@@ -192,14 +194,19 @@ def forward(self, label, inst, image, feat, infer=False):
         # Only return the fake_B image if necessary to save BW
         return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]
 
-    def inference(self, label, inst):
+    def inference(self, label, inst, image=None):
         # Encode Inputs
-        input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)
+        image = Variable(image) if image is not None else None
+        input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)
 
         # Fake Generation
-        if self.use_features:
-            # sample clusters from precomputed features
-            feat_map = self.sample_features(inst_map)
+        if self.use_features:
+            if self.opt.use_encoded_image:
+                # encode the real image to get feature map
+                feat_map = self.netE.forward(real_image, inst_map)
+            else:
+                # sample clusters from precomputed features
+                feat_map = self.sample_features(inst_map)
             input_concat = torch.cat((input_label, feat_map), dim=1)
         else:
             input_concat = input_label
@@ -214,7 +221,7 @@ def inference(self, label, inst):
     def sample_features(self, inst):
         # read precomputed feature clusters
         cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
-        features_clustered = np.load(cluster_path).item()
+        features_clustered = np.load(cluster_path, encoding='latin1').item()
 
         # randomly sample from the feature clusters
         inst_np = inst.cpu().numpy().astype(int)
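Context for the `sample_features` change above: the cluster file is a pickled Python 2 dict wrapped in a NumPy object array, which Python 3 can only unpickle with the `latin1` codec. A minimal loading sketch, assuming the default `features_clustered_010.npy` exists; `allow_pickle=True` is an extra assumption needed on NumPy 1.16.3 and later, not part of this patch:

```python
import numpy as np

# The .npy file stores a pickled dict inside a 0-d object array.
# encoding='latin1' lets Python 3 decode byte strings pickled under Python 2;
# allow_pickle=True is required on NumPy >= 1.16.3 (assumption, not in patch).
features_clustered = np.load('features_clustered_010.npy',
                             encoding='latin1', allow_pickle=True).item()
# .item() unwraps the 0-d array back into the dict of per-label clusters.
print(type(features_clustered), len(features_clustered))
```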
diff --git a/options/test_options.py b/options/test_options.py
index 29be0370..32d926a5 100755
--- a/options/test_options.py
+++ b/options/test_options.py
@@ -12,6 +12,7 @@ def initialize(self):
         self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
         self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
         self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
+        self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
         self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
         self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
         self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
diff --git a/test.py b/test.py
index b7f6dbc5..da1312dd 100755
--- a/test.py
+++ b/test.py
@@ -58,7 +58,7 @@
     elif opt.onnx:
         generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
     else:
-        generated = model.inference(data['label'], data['inst'])
+        generated = model.inference(data['label'], data['inst'], data['image'])
 
     visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                            ('synthesized_image', util.tensor2im(generated.data[0]))])
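Taken together, these pieces let test time reconstruct an image's own instance-wise features (via `netE`) instead of sampling from precomputed clusters. A rough sketch of driving the new path from Python, assuming a checkpoint trained with instance-wise features so the encoder exists; the setup lines mirror `test.py` but are illustrative, not part of the patch:

```python
# Illustrative sketch, not part of the patch.
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model

opt = TestOptions().parse(save=False)
opt.instance_feat = True             # feature encoding was enabled at training time
opt.use_encoded_image = True         # normally passed as --use_encoded_image
data_loader = CreateDataLoader(opt)  # aligned_dataset now also loads the B images
dataset = data_loader.load_data()
model = create_model(opt)

for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    # The real image is forwarded so netE can encode per-instance features
    # instead of sampling from the precomputed clusters.
    generated = model.inference(data['label'], data['inst'], data['image'])
```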