diff --git a/models/RecNet/model.py b/models/RecNet/model.py index d93eafa..2355391 100644 --- a/models/RecNet/model.py +++ b/models/RecNet/model.py @@ -685,6 +685,9 @@ def BeamDecoding(self,feats, width, alpha=0.,max_caption_len = 15): rfunc = np.vectorize(lambda t: '' if t == 'EOS' else t) # to transform EOS to null string lfunc = np.vectorize(lambda t: '' if t == 'SOS' else t) # to transform SOS to null string pfunc = np.vectorize(lambda t: '' if t == 'PAD' else t) # to transform PAD to null string + + if self.cfg.opt_encoder: + feats = self.encoder(feats) hidden = torch.zeros(self.cfg.n_layers, batch_size, self.cfg.decoder_hidden_size).to(self.device) if self.cfg.decoder_type == 'lstm':