diff --git a/README.md b/README.md
index 80b425a7..3827e357 100755
--- a/README.md
+++ b/README.md
@@ -3,12 +3,12 @@
# pix2pixHD
-### [[Project]](https://tcwang0509.github.io/pix2pixHD/) [[Youtube]](https://youtu.be/3AIpPlzM_qs) [[Paper]](https://arxiv.org/pdf/1711.11585.pdf)
+### [[Project]](https://tcwang0509.github.io/pix2pixHD/) [[Youtube]](https://youtu.be/3AIpPlzM_qs) [[Paper]](https://arxiv.org/pdf/1711.11585.pdf) [[Slides]](https://drive.google.com/open?id=1pObGZvEgPILtRsqEP07sv8CtNuQ55xur)
PyTorch implementation of our method for high-resolution (e.g. 2048x1024) photorealistic image-to-image translation. It can be used to turn semantic label maps into photorealistic images or to synthesize portraits from face label maps.
[High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/)
[Ting-Chun Wang](https://tcwang0509.github.io/)<sup>1</sup>, [Ming-Yu Liu](http://mingyuliu.net/)<sup>1</sup>, [Jun-Yan Zhu](http://people.eecs.berkeley.edu/~junyanz/)<sup>2</sup>, Andrew Tao<sup>1</sup>, [Jan Kautz](http://jankautz.com/)<sup>1</sup>, [Bryan Catanzaro](http://catanzaro.name/)<sup>1</sup>
<sup>1</sup>NVIDIA Corporation, <sup>2</sup>UC Berkeley
- In arxiv, 2017.
+ In CVPR, 2018.
## Image-to-image translation at 2k/1k resolution
- Our label-to-streetview results
@@ -123,11 +123,11 @@ If only GPUs with 12G memory are available, please use the 12G script (`bash ./s
If you find this useful for your research, please cite using the following.
```
-@article{wang2017highres,
+@inproceedings{wang2018pix2pixHD,
title={High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs},
- author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
- journal={arXiv preprint arXiv:1711.11585},
- year={2017}
+ author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+ year={2018}
}
```
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index 81d54c49..6e01a012 100755
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -16,7 +16,7 @@ def initialize(self, opt):
self.A_paths = sorted(make_dataset(self.dir_A))
### input B (real images)
- if opt.isTrain:
+ if opt.isTrain or opt.use_encoded_image:
dir_B = '_B' if self.opt.label_nc == 0 else '_img'
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
self.B_paths = sorted(make_dataset(self.dir_B))
@@ -48,7 +48,7 @@ def __getitem__(self, index):
B_tensor = inst_tensor = feat_tensor = 0
### input B (real images)
- if self.opt.isTrain:
+ if self.opt.isTrain or self.opt.use_encoded_image:
B_path = self.B_paths[index]
B = Image.open(B_path).convert('RGB')
transform_B = get_transform(self.opt, params)
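Note on the dataset change above: real images (the `_B` / `_img` folder) are now loaded not only during training but also at test time whenever the new `--use_encoded_image` flag is set, since the encoder needs the paired photo. A condensed sketch of the resulting logic (field names follow the repo's options; this is not the verbatim dataset code):
```python
from PIL import Image

def load_pair(opt, index, A_paths, B_paths):
    A = Image.open(A_paths[index])  # label map
    B = None
    # Real image B is needed for training (ground truth) and now also
    # for testing with --use_encoded_image (input to the encoder netE).
    if opt.isTrain or opt.use_encoded_image:
        B = Image.open(B_paths[index]).convert('RGB')
    return A, B
```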
diff --git a/encode_features.py b/encode_features.py
index 84372644..3adfedd8 100755
--- a/encode_features.py
+++ b/encode_features.py
@@ -12,6 +12,7 @@
opt.serial_batches = True
opt.no_flip = True
opt.instance_feat = True
+opt.continue_train = True
name = 'features'
save_path = os.path.join(opt.checkpoints_dir, opt.name)
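Setting `opt.continue_train = True` here makes the model restore its saved checkpoints instead of initializing fresh networks; without it, `encode_features.py` would compute features with an untrained encoder. A rough sketch of the loading condition this relies on (mirroring the convention in `models/`, not quoted verbatim):
```python
# In pix2pixHD, saved weights are restored when testing or resuming:
if not opt.isTrain or opt.continue_train:
    # 'E' is the encoder that produces the instance-wise features
    model.load_network(model.netE, 'E', opt.which_epoch)
```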
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
index 57c18b6c..851460e1 100755
--- a/models/pix2pixHD_model.py
+++ b/models/pix2pixHD_model.py
@@ -126,7 +126,7 @@ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None,
if not self.opt.no_instance:
inst_map = inst_map.data.cuda()
edge_map = self.get_edges(inst_map)
- input_label = torch.cat((input_label, edge_map), dim=1)
+ input_label = torch.cat((input_label, edge_map), dim=1)
input_label = Variable(input_label, volatile=infer)
# real images for training
@@ -138,6 +138,8 @@ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None,
# get precomputed feature maps
if self.opt.load_features:
feat_map = Variable(feat_map.data.cuda())
+ if self.opt.label_feat:
+ inst_map = label_map.cuda()
return input_label, inst_map, real_image, feat_map
@@ -192,14 +194,19 @@ def forward(self, label, inst, image, feat, infer=False):
# Only return the fake_B image if necessary to save BW
return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]
- def inference(self, label, inst):
+ def inference(self, label, inst, image=None):
# Encode Inputs
- input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)
+ image = Variable(image) if image is not None else None
+ input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)
# Fake Generation
- if self.use_features:
- # sample clusters from precomputed features
- feat_map = self.sample_features(inst_map)
+ if self.use_features:
+ if self.opt.use_encoded_image:
+ # encode the real image to get feature map
+ feat_map = self.netE.forward(real_image, inst_map)
+ else:
+ # sample clusters from precomputed features
+ feat_map = self.sample_features(inst_map)
input_concat = torch.cat((input_label, feat_map), dim=1)
else:
input_concat = input_label
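For context on the new encoder branch: `netE` is the instance-wise feature encoder from the paper, whose forward pass average-pools the encoder output within each instance mask, so every object instance ends up with a single feature vector. A simplified sketch of that pooling step (assumed from `models/networks.py`, not the verbatim code):
```python
import torch

def instance_wise_pool(outputs, inst_map):
    # Replace each instance's pixels with the mean feature of that
    # instance, channel by channel.
    pooled = outputs.clone()
    for i in inst_map.unique():
        mask = (inst_map == int(i))          # 1 x 1 x H x W
        for c in range(outputs.size(1)):
            region = outputs[:, c:c + 1][mask]
            pooled[:, c:c + 1][mask] = region.mean()
    return pooled
```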
@@ -214,7 +221,7 @@ def inference(self, label, inst):
def sample_features(self, inst):
# read precomputed feature clusters
cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
- features_clustered = np.load(cluster_path).item()
+ features_clustered = np.load(cluster_path, encoding='latin1').item()
# randomly sample from the feature clusters
inst_np = inst.cpu().numpy().astype(int)
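The `encoding='latin1'` argument is the standard fix for loading Python-2-era pickles under Python 3; without it, `np.load` typically raises a `UnicodeDecodeError` on the clustered-feature file. Minimal illustration (uses the repo's default cluster filename; newer NumPy releases additionally require `allow_pickle=True`):
```python
import numpy as np

# The .npy file wraps a pickled dict of per-class feature clusters.
# 'latin1' maps Python 2 byte strings 1:1 onto str, so the dict
# unpickles cleanly under Python 3.
features_clustered = np.load('features_clustered_010.npy',
                             encoding='latin1').item()
```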
diff --git a/options/test_options.py b/options/test_options.py
index 29be0370..32d926a5 100755
--- a/options/test_options.py
+++ b/options/test_options.py
@@ -12,6 +12,7 @@ def initialize(self):
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
+ self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
diff --git a/test.py b/test.py
index b7f6dbc5..da1312dd 100755
--- a/test.py
+++ b/test.py
@@ -58,7 +58,7 @@
elif opt.onnx:
generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
else:
- generated = model.inference(data['label'], data['inst'])
+ generated = model.inference(data['label'], data['inst'], data['image'])
visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
('synthesized_image', util.tensor2im(generated.data[0]))])
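Taken together, the test path now supports both inference modes. A hypothetical caller sketch (assumes `model` is a constructed Pix2PixHDModel and `data` is a batch from the dataloader, as in `test.py`):
```python
if opt.use_encoded_image:
    # Features come from encoding the paired real photo with netE,
    # so the output reproduces that photo's style per instance.
    generated = model.inference(data['label'], data['inst'], data['image'])
else:
    # Features are sampled from the precomputed clusters
    # (see --cluster_path); the image argument may be omitted.
    generated = model.inference(data['label'], data['inst'])
```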