Skip to content

Commit

Permalink
Fix off-by-one errors in generation code and token streaming callback.
Browse files Browse the repository at this point in the history
In the generation code we were feeding the last token of the prompt
through the transformer twice. The new version fixes that, and it also
works correctly in the case where Prefill is completely disabled.
  • Loading branch information
szabadka committed Apr 4, 2024
1 parent ede337f commit 71ead04
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
17 changes: 13 additions & 4 deletions gemma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -666,12 +666,16 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,

size_t pos_gen_start = pos_offset;
int token = prompt.at(pos_offset);
stream_token(token, 0);
for (size_t generate_pos = 0;
pos < max_tokens && generate_pos < max_generated_tokens;
++pos, ++pos_offset, ++generate_pos) {
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
float* final_activation = activations.x.data();
if (pos_offset >= prompt_size) {
// The condition below is always true if we are doing Prefill above.
// We keep it here for clarity so that the code is correct even if Prefill
// is disabled.
if (pos_offset >= prompt_size - 1) {
PROFILER_ZONE("Gen.Embedding");
// Generation phase
MatVec<kVocabSize, TConfig::kModelDim>(
Expand All @@ -681,9 +685,14 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
Softmax(activations.logits.data(), kVocabSize);
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
gen, temperature, accept_token);
}
if (!stream_token(token, activations.logits[token])) {
token = EOS_ID;
if (!stream_token(token, activations.logits[token])) {
token = EOS_ID;
}
} else {
// We would take this branch if we were not doing Prefill but would
// process the tokens of the prompt one at a time.
token = prompt.at(pos_offset + 1);
stream_token(token, 0);
}
if (token == EOS_ID) {
if (verbosity >= 2) {
Expand Down
3 changes: 2 additions & 1 deletion run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
verbosity](int token, float) {
++abs_pos;
++current_pos;
if (current_pos < prompt_size) {
// <= since position is incremented before
if (current_pos <= prompt_size) {
std::cerr << "." << std::flush;
} else if (token == gcpp::EOS_ID) {
if (!args.multiturn) {
Expand Down

0 comments on commit 71ead04

Please sign in to comment.