Skip to content

Commit

Permalink
fixed error in beam decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
nasib-ullah authored Mar 24, 2022
1 parent 42afa4e commit 37c8844
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions models/RecNet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,9 @@ def BeamDecoding(self,feats, width, alpha=0.,max_caption_len = 15):
rfunc = np.vectorize(lambda t: '' if t == 'EOS' else t) # to transform EOS to null string
lfunc = np.vectorize(lambda t: '' if t == 'SOS' else t) # to transform SOS to null string
pfunc = np.vectorize(lambda t: '' if t == 'PAD' else t) # to transform PAD to null string

if self.cfg.opt_encoder:
feats = self.encoder(feats)

hidden = torch.zeros(self.cfg.n_layers, batch_size, self.cfg.decoder_hidden_size).to(self.device)
if self.cfg.decoder_type == 'lstm':
Expand Down

0 comments on commit 37c8844

Please sign in to comment.