Skip to content

Commit

Permalink
PixelCNNSampler sample function fix (#148)
Browse files Browse the repository at this point in the history
* fix pixelcnnsampler sample func

* ignore tests/dummy_output_dir

* Update .gitignore
  • Loading branch information
willxxy authored Jul 17, 2024
1 parent bb785de commit 9c7ab84
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ examples/notebooks/models_training/my_model/
examples/showcases copy/*
examples/notebooks/models_training/generate.py
dummy_output_dir/*
tests/dummy_output_dir/
examples/notebooks/models_training/my_model/*
examples/notebooks/test.ipynb
examples/notebooks/dummy_output_dir/*
Expand Down
5 changes: 4 additions & 1 deletion src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,10 @@ def sample(
probs = F.softmax(out[:, :, k, l], dim=-1)
z[:, :, k, l] = torch.multinomial(probs, 1).float()

z_quant = self.model.quantizer.embeddings(z.reshape(z.shape[0], -1).long())
if isinstance(self.model.quantizer.embeddings, torch.Tensor):
z_quant = self.model.quantizer.embeddings[z.reshape(z.shape[0], -1).long()]
else:
z_quant = self.model.quantizer.embeddings(z.reshape(z.shape[0], -1).long())

if self.needs_reshape:
z_quant = z_quant.reshape(z.shape[0], -1)
Expand Down

0 comments on commit 9c7ab84

Please sign in to comment.