Skip to content

Commit

Permalink
fix feature encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
tcwang0509 committed Nov 2, 2018
1 parent 20687df commit e50d03e
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 16 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
<br><br><br><br>

# pix2pixHD
### [[Project]](https://tcwang0509.github.io/pix2pixHD/) [[Youtube]](https://youtu.be/3AIpPlzM_qs) [[Paper]](https://arxiv.org/pdf/1711.11585.pdf) <br>
### [[Project]](https://tcwang0509.github.io/pix2pixHD/) [[Youtube]](https://youtu.be/3AIpPlzM_qs) [[Paper]](https://arxiv.org/pdf/1711.11585.pdf) [[Slides]](https://drive.google.com/open?id=1pObGZvEgPILtRsqEP07sv8CtNuQ55xur) <br>
Pytorch implementation of our method for high-resolution (e.g. 2048x1024) photorealistic image-to-image translation. It can be used for turning semantic label maps into photo-realistic images or synthesizing portraits from face label maps. <br><br>
[High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/)
[Ting-Chun Wang](https://tcwang0509.github.io/)<sup>1</sup>, [Ming-Yu Liu](http://mingyuliu.net/)<sup>1</sup>, [Jun-Yan Zhu](http://people.eecs.berkeley.edu/~junyanz/)<sup>2</sup>, Andrew Tao<sup>1</sup>, [Jan Kautz](http://jankautz.com/)<sup>1</sup>, [Bryan Catanzaro](http://catanzaro.name/)<sup>1</sup>
<sup>1</sup>NVIDIA Corporation, <sup>2</sup>UC Berkeley
In arxiv, 2017.
In CVPR, 2018.

## Image-to-image translation at 2k/1k resolution
- Our label-to-streetview results
Expand Down Expand Up @@ -123,11 +123,11 @@ If only GPUs with 12G memory are available, please use the 12G script (`bash ./s
If you find this useful for your research, please use the following.

```
@article{wang2017highres,
@inproceedings{wang2018pix2pixHD,
title={High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs},
author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
journal={arXiv preprint arXiv:1711.11585},
year={2017}
author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2018}
}
```

Expand Down
4 changes: 2 additions & 2 deletions data/aligned_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def initialize(self, opt):
self.A_paths = sorted(make_dataset(self.dir_A))

### input B (real images)
if opt.isTrain:
if opt.isTrain or opt.use_encoded_image:
dir_B = '_B' if self.opt.label_nc == 0 else '_img'
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
self.B_paths = sorted(make_dataset(self.dir_B))
Expand Down Expand Up @@ -48,7 +48,7 @@ def __getitem__(self, index):

B_tensor = inst_tensor = feat_tensor = 0
### input B (real images)
if self.opt.isTrain:
if self.opt.isTrain or self.opt.use_encoded_image:
B_path = self.B_paths[index]
B = Image.open(B_path).convert('RGB')
transform_B = get_transform(self.opt, params)
Expand Down
1 change: 1 addition & 0 deletions encode_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
opt.serial_batches = True
opt.no_flip = True
opt.instance_feat = True
opt.continue_train = True

name = 'features'
save_path = os.path.join(opt.checkpoints_dir, opt.name)
Expand Down
21 changes: 14 additions & 7 deletions models/pix2pixHD_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None,
if not self.opt.no_instance:
inst_map = inst_map.data.cuda()
edge_map = self.get_edges(inst_map)
input_label = torch.cat((input_label, edge_map), dim=1)
input_label = torch.cat((input_label, edge_map), dim=1)
input_label = Variable(input_label, volatile=infer)

# real images for training
Expand All @@ -138,6 +138,8 @@ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None,
# get precomputed feature maps
if self.opt.load_features:
feat_map = Variable(feat_map.data.cuda())
if self.opt.label_feat:
inst_map = label_map.cuda()

return input_label, inst_map, real_image, feat_map

Expand Down Expand Up @@ -192,14 +194,19 @@ def forward(self, label, inst, image, feat, infer=False):
# Only return the fake_B image if necessary to save BW
return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]

def inference(self, label, inst):
def inference(self, label, inst, image=None):
# Encode Inputs
input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)
image = Variable(image) if image is not None else None
input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)

# Fake Generation
if self.use_features:
# sample clusters from precomputed features
feat_map = self.sample_features(inst_map)
if self.use_features:
if self.opt.use_encoded_image:
# encode the real image to get feature map
feat_map = self.netE.forward(real_image, inst_map)
else:
# sample clusters from precomputed features
feat_map = self.sample_features(inst_map)
input_concat = torch.cat((input_label, feat_map), dim=1)
else:
input_concat = input_label
Expand All @@ -214,7 +221,7 @@ def inference(self, label, inst):
def sample_features(self, inst):
# read precomputed feature clusters
cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
features_clustered = np.load(cluster_path).item()
features_clustered = np.load(cluster_path, encoding='latin1').item()

# randomly sample from the feature clusters
inst_np = inst.cpu().numpy().astype(int)
Expand Down
1 change: 1 addition & 0 deletions options/test_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def initialize(self):
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
elif opt.onnx:
generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
else:
generated = model.inference(data['label'], data['inst'])
generated = model.inference(data['label'], data['inst'], data['image'])

visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
('synthesized_image', util.tensor2im(generated.data[0]))])
Expand Down

0 comments on commit e50d03e

Please sign in to comment.