Skip to content

Commit

Permalink
add estimator in greedy inference (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s authored Oct 25, 2024
1 parent d7959ba commit 5bf241b
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 1 deletion.
2 changes: 2 additions & 0 deletions eole/predict/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def predict_batch(self, batch, attn_debug):
top_p=self.top_p,
beam_size=self.beam_size,
ban_unk_token=self.ban_unk_token,
add_estimator=self.add_estimator,
)
else:
# TODO: support these blacklisted features
Expand All @@ -66,6 +67,7 @@ def predict_batch(self, batch, attn_debug):
stepwise_penalty=self.stepwise_penalty,
ratio=self.ratio,
ban_unk_token=self.ban_unk_token,
add_estimator=self.add_estimator,
)
return self._predict_batch_with_strategy(batch, decode_strategy)

Expand Down
1 change: 1 addition & 0 deletions eole/predict/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def predict_batch(self, batch, attn_debug, scoring=False):
top_p=self.top_p,
beam_size=self.beam_size,
ban_unk_token=self.ban_unk_token,
add_estimator=self.add_estimator,
)
else:
# TODO: support these blacklisted features
Expand Down
6 changes: 5 additions & 1 deletion eole/predict/greedy_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def __init__(
self.topk_scores = None
self.beam_size = beam_size
self.n_best = n_best
if add_estimator:
self.num_hyp = self.beam_size
else:
self.num_hyp = self.n_best

def initialize(self, enc_out, src_len, device=None, target_prefix=None):
"""Initialize for decoding."""
Expand Down Expand Up @@ -282,7 +286,7 @@ def update_finished(self):
if self.done:
for b in range(self.batch_size):
best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True)[
: self.n_best
: self.num_hyp
]
for score, pred, attn in best_hyp:
self.scores[b].append(score)
Expand Down
1 change: 1 addition & 0 deletions eole/predict/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def predict_batch(self, batch, attn_debug):
top_p=self.top_p,
beam_size=self.beam_size,
ban_unk_token=self.ban_unk_token,
add_estimator=self.add_estimator,
)
else:
# TODO: support these blacklisted features
Expand Down

0 comments on commit 5bf241b

Please sign in to comment.