How to implement beam search in JAX/Flax? #917
Asked by marcvanzee in Q&A
Answered by marcvanzee, Jan 21, 2021
We have already implemented this in our WMT example, in decode.py. It assumes a function that takes integer tokens and returns logits, and it has been optimized for TPUs.
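Here is a minimal beam-search sketch in JAX that shows the core idea. It is not the WMT example's implementation, and it makes three simplifying assumptions. First, `logits_fn` is hypothetical: it maps the last token of each beam (int32, shape `[beam_size]`) to next-token logits of shape `[beam_size, vocab_size]`, so the toy model conditions only on the previous token and needs no cache. Second, it decodes a single example. Third, it runs for a fixed number of steps. At each step every live hypothesis is extended by every vocabulary token, and `jax.lax.top_k` keeps the `beam_size` candidates with the highest cumulative log-probability.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("logits_fn", "beam_size", "max_len"))
def beam_search(logits_fn, bos_id, beam_size, max_len):
    """Fixed-length beam search for one example (illustrative sketch).

    logits_fn is a hypothetical stand-in for the model: it takes the last
    token of each beam, int32[beam_size], and returns float[beam_size, vocab].
    Returns (tokens int32[beam_size, max_len], log-prob scores[beam_size]),
    sorted best-first.
    """
    # Column 0 holds BOS; generated tokens are written to columns 1..max_len.
    seqs = jnp.zeros((beam_size, max_len + 1), dtype=jnp.int32).at[:, 0].set(bos_id)
    # Only beam 0 is alive initially, so the first expansion does not yield
    # beam_size identical copies of the same hypothesis.
    scores = jnp.full((beam_size,), -jnp.inf).at[0].set(0.0)

    def step(i, state):
        seqs, scores = state
        logits = logits_fn(seqs[:, i])                       # [beam, vocab]
        vocab = logits.shape[-1]
        # Cumulative log-prob of every (beam, next token) extension.
        cand = scores[:, None] + jax.nn.log_softmax(logits, axis=-1)
        top_scores, flat_idx = jax.lax.top_k(cand.reshape(-1), beam_size)
        parent = flat_idx // vocab                           # beam being extended
        token = flat_idx % vocab                             # token appended to it
        # Copy each winner's parent prefix, then write the new token.
        seqs = seqs[parent].at[:, i + 1].set(token.astype(jnp.int32))
        return seqs, top_scores

    seqs, scores = jax.lax.fori_loop(0, max_len, step, (seqs, scores))
    return seqs[:, 1:], scores
```

For example, with a toy bigram table `T` of shape `[vocab, vocab]`, you could call `beam_search(logits_fn=lambda toks: T[toks], bos_id=0, beam_size=4, max_len=10)`. Because `jax.lax.top_k` returns scores in descending order, row 0 of the result is the best hypothesis. A production decoder such as the one referenced above also has to handle batching, end-of-sequence tokens, length normalization and a decoding cache, all of which this sketch omits.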