From 9c7ab8473ef0512e09b3f18d4816124a868c32be Mon Sep 17 00:00:00 2001 From: willxxy <90741489+willxxy@users.noreply.github.com> Date: Wed, 17 Jul 2024 03:22:13 -0400 Subject: [PATCH] PixelCNNSampler sample function fix (#148) * fix pixelcnnsampler sample func * ignore tests/dummy_output_dir * Update .gitignore --- .gitignore | 1 + src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 06ea2c7b..727bfa12 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ examples/notebooks/models_training/my_model/ examples/showcases copy/* examples/notebooks/models_training/generate.py dummy_output_dir/* +tests/dummy_output_dir/ examples/notebooks/models_training/my_model/* examples/notebooks/test.ipynb examples/notebooks/dummy_output_dir/* diff --git a/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py b/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py index 9521fa81..f4b6f7d8 100644 --- a/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py +++ b/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py @@ -203,7 +203,10 @@ def sample( probs = F.softmax(out[:, :, k, l], dim=-1) z[:, :, k, l] = torch.multinomial(probs, 1).float() - z_quant = self.model.quantizer.embeddings(z.reshape(z.shape[0], -1).long()) + if isinstance(self.model.quantizer.embeddings, torch.Tensor): + z_quant = self.model.quantizer.embeddings[z.reshape(z.shape[0], -1).long()] + else: + z_quant = self.model.quantizer.embeddings(z.reshape(z.shape[0], -1).long()) if self.needs_reshape: z_quant = z_quant.reshape(z.shape[0], -1)