How to implement beam search in JAX/Flax? #917

Answered by marcvanzee
marcvanzee asked this question in Q&A

We have implemented this in our WMT example, in `decode.py`. It assumes a function that takes integer tokens and returns logits, and it has been optimized for TPUs.
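
As a rough illustration of the general technique under the same "integer tokens in, logits out" contract, here is a minimal sketch. It is not the code in `decode.py`. The names `beam_search`, the `logits_fn(seqs, t)` callback signature, and the toy transition table are hypothetical and were chosen only for this example. The sketch keeps every array at a fixed shape and runs the decode loop with `jax.lax.fori_loop`, because XLA compiles fixed shapes, and that matters on TPUs. It decodes a single example and leaves out batching, length penalties, and a key/value cache.

```python
import jax
import jax.numpy as jnp


def beam_search(logits_fn, bos_id, eos_id, beam_size, max_len):
    """Fixed-shape beam search for one example (illustrative sketch only).

    logits_fn(seqs, t): hypothetical callback. Given int32 tokens `seqs` of
    shape [beam_size, max_len], valid up to position t, it returns float32
    logits of shape [beam_size, vocab] for the token at position t + 1.
    """
    seqs = jnp.zeros((beam_size, max_len), jnp.int32).at[:, 0].set(bos_id)
    # Only beam 0 starts live. The others get -inf so the first step does
    # not fill the beam with copies of the same hypothesis.
    scores = jnp.full((beam_size,), -jnp.inf, jnp.float32).at[0].set(0.0)
    finished = jnp.zeros((beam_size,), bool)

    def step(t, state):
        seqs, scores, finished = state
        log_probs = jax.nn.log_softmax(logits_fn(seqs, t), axis=-1)
        vocab = log_probs.shape[-1]
        # A finished beam may only emit EOS again, at zero extra cost,
        # so its score is carried forward unchanged.
        eos_only = jnp.full((vocab,), -jnp.inf, jnp.float32).at[eos_id].set(0.0)
        log_probs = jnp.where(finished[:, None], eos_only, log_probs)
        # Score every (beam, token) extension and keep the best beam_size.
        candidates = (scores[:, None] + log_probs).reshape(-1)
        top_scores, top_idx = jax.lax.top_k(candidates, beam_size)
        parent, token = top_idx // vocab, top_idx % vocab
        # Reorder the surviving prefixes, then write the new token.
        seqs = seqs[parent].at[:, t + 1].set(token)
        finished = finished[parent] | (token == eos_id)
        return seqs, top_scores, finished

    seqs, scores, _ = jax.lax.fori_loop(
        0, max_len - 1, step, (seqs, scores, finished))
    return seqs, scores  # sorted best-first by total log-probability


# Hypothetical toy "model": next-token probabilities depend only on the
# previous token. Vocabulary: 0 = BOS, 1 = EOS, 2 = A, 3 = B.
probs = jnp.array([
    [0.0, 0.1, 0.5, 0.4],    # after BOS
    [0.0, 1.0, 0.0, 0.0],    # after EOS (masked by the search anyway)
    [0.0, 0.4, 0.3, 0.3],    # after A
    [0.0, 0.9, 0.05, 0.05],  # after B
])
log_table = jnp.log(probs)


def toy_logits(seqs, t):
    # Look up next-token log-probabilities from each beam's token at position t.
    return log_table[seqs[:, t]]


seqs, scores = beam_search(toy_logits, bos_id=0, eos_id=1, beam_size=2, max_len=4)
print(seqs[0], jnp.exp(scores[0]))
```

With `beam_size=2`, the best beam is `[0, 3, 1, 1]` (BOS, B, EOS, EOS) with probability 0.4 × 0.9 = 0.36. Greedy decoding would commit to A after BOS and finish at 0.5 × 0.4 = 0.2. Keeping a second hypothesis alive is what lets the search recover the better sequence.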

Answer selected by marcvanzee

This discussion was converted from issue #875 on January 21, 2021 13:55.