diff --git a/experimental/transformer/src/OPTGenerate.cc b/experimental/transformer/src/OPTGenerate.cc
index b94cb5fa..6b5a6a86 100644
--- a/experimental/transformer/src/OPTGenerate.cc
+++ b/experimental/transformer/src/OPTGenerate.cc
@@ -12,11 +12,10 @@ void OPT_sample_repetition_penalty(OPT_token_data_array * candidates, const int
             continue;
         }
 
-        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
-        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
         if (candidates->data[i].logit <= 0) {
             candidates->data[i].logit *= penalty;
-        } else {
+        }
+        else {
             candidates->data[i].logit /= penalty;
         }
     }
@@ -70,7 +69,6 @@ void OPT_sample_temperature(OPT_token_data_array * candidates_p, float temp) {
 //
 // sampling
 //
-
 void OPT_sample_softmax(OPT_token_data_array * candidates) {
     assert(candidates->size > 0);
 
@@ -122,7 +120,8 @@ void OPT_sample_top_k(OPT_token_data_array * candidates, int k, size_t min_keep)
     };
     if (k == (int) candidates->size) {
         std::sort(candidates->data, candidates->data + candidates->size, comp);
-    } else {
+    }
+    else {
         std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
     }
     candidates->sorted = true;
@@ -331,7 +330,6 @@ std::vector<int> OPTGenerate(std::vector<int> input_ids,
     std::vector<int> embd;
     std::vector<int> generate_ids;
 
-    //int n_past = 0;
     int n_consumed = 0;
     while ((int) input_ids.size() > n_consumed) {
         embd.push_back(input_ids[n_consumed]);
@@ -397,7 +395,6 @@ std::vector<int> OPTGenerate(std::vector<int> input_ids,
         int id = 0;
 
         if (temp <= 0) {
-            // Greedy sampling
            id = OPT_sample_token_greedy(&candidates_p);
         }
         else {
@@ -422,15 +419,12 @@ std::vector<int> OPTGenerate(std::vector<int> input_ids,
                 id = OPT_sample_token(&candidates_p);
             }
         }
 
-        // printf("`%d`", candidates_p.size);
-        // add it to the context
         last_n_tokens.erase(last_n_tokens.begin());
         last_n_tokens.push_back(id);
 
         embd.push_back(id);
         generate_ids.push_back(id);
 
-        // decrement remaining sampling budget
         --n_remain;
     }
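For reference, a minimal standalone sketch of the repetition-penalty rule the first hunk touches, assuming the OPT_token_data / OPT_token_data_array layout implied by the diff (the id and p fields and the apply_repetition_penalty name are illustrative, not taken from the file): non-positive logits are multiplied by the penalty and positive logits are divided, so a repeated token always becomes less likely rather than more likely.

#include <cstddef>

struct OPT_token_data {
    int id;        // assumed: token id compared against recent history
    float logit;
    float p;
};

struct OPT_token_data_array {
    OPT_token_data *data;
    size_t size;
    bool sorted;
};

// Illustrative only: penalize every candidate whose id appears in last_tokens,
// using the same multiply/divide split as the hunk above.
void apply_repetition_penalty(OPT_token_data_array *candidates,
                              const int *last_tokens, size_t n_last,
                              float penalty) {
    for (size_t i = 0; i < candidates->size; i++) {
        bool seen = false;
        for (size_t j = 0; j < n_last; j++) {
            if (last_tokens[j] == candidates->data[i].id) {
                seen = true;
                break;
            }
        }
        if (!seen) {
            continue;
        }

        // Dividing a negative logit by a penalty > 1 would raise its probability,
        // so negative logits are multiplied by the penalty instead.
        if (candidates->data[i].logit <= 0) {
            candidates->data[i].logit *= penalty;
        }
        else {
            candidates->data[i].logit /= penalty;
        }
    }
}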