From db54093260dd3bafa9900b29666fcfec9810b856 Mon Sep 17 00:00:00 2001 From: meenchen Date: Tue, 23 May 2023 18:31:38 +0800 Subject: [PATCH] hook up the opt model inference interface into text generate --- .../transformer/include/OPTGenerate.h | 61 +++--- experimental/transformer/src/OPTGenerate.cc | 176 +++++++++--------- .../transformer/tests/test_OPTGenerate.cc | 50 ++--- 3 files changed, 141 insertions(+), 146 deletions(-) diff --git a/experimental/transformer/include/OPTGenerate.h b/experimental/transformer/include/OPTGenerate.h index 59ddba4c..bafac137 100644 --- a/experimental/transformer/include/OPTGenerate.h +++ b/experimental/transformer/include/OPTGenerate.h @@ -1,14 +1,15 @@ +#include +#include #include -#include -#include +#include #include -#include -#include #include -#include -#include +#include +#include +#include #include "OPTForCausalLM.h" +#include "OPTTokenizer.h" #include "operators.h" #include "utils.h" @@ -27,29 +28,29 @@ typedef struct OPT_token_data_array { } OPT_token_data_array; struct opt_params { - int32_t seed = -1; // RNG seed - int32_t n_threads = 1; // TODO: fix this - int32_t n_predict = 128; // new tokens to predict - int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) - int32_t n_ctx = 512; // context size - int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_vocab = 50272; // vocabulary size + int32_t seed = -1; // RNG seed + int32_t n_threads = 1; // TODO: fix this + int32_t n_predict = 128; // new tokens to predict + int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) + int32_t n_ctx = 512; // context size + int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_vocab = 50272; // vocabulary size // sampling parameters - std::unordered_map logit_bias; // logit bias for specific tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // 1.0 = disabled - float repeat_penalty = 1.10f; // 1.0 = disabled - int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float frequency_penalty = 0.00f; // 0.0 = disabled - float presence_penalty = 0.00f; // 0.0 = disabled - int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate + std::unordered_map logit_bias; // logit bias for specific tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + float temp = 0.80f; // 1.0 = disabled + float repeat_penalty = 1.10f; // 1.0 = disabled + int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float frequency_penalty = 0.00f; // 0.0 = disabled + float presence_penalty = 0.00f; // 0.0 = disabled + int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate }; void OPT_sample_repetition_penalty(OPT_token_data_array* candidates, const int* last_tokens, size_t last_tokens_size, @@ -79,6 +80,8 @@ void OPT_sample_typical(OPT_token_data_array* candidates, float p, size_t min_ke void OPT_sample_top_p(OPT_token_data_array* candidates, float p, size_t min_keep); +std::vector OPTGenerate(OPTForCausalLM model, std::vector input_ids, + const struct opt_params generation_config); -std::vector OPTGenerate(std::vector input_ids, - const struct opt_params generation_config); +void OPTGenerate_interactive(OPTForCausalLM model, std::vector input_ids, + const struct opt_params generation_config, Encoder encoder); diff --git a/experimental/transformer/src/OPTGenerate.cc b/experimental/transformer/src/OPTGenerate.cc index 6b5a6a86..a9cfa6fe 100644 --- a/experimental/transformer/src/OPTGenerate.cc +++ b/experimental/transformer/src/OPTGenerate.cc @@ -1,7 +1,9 @@ #include "OPTGenerate.h" -void OPT_sample_repetition_penalty(OPT_token_data_array * candidates, const int * last_tokens, - size_t last_tokens_size, float penalty) { +#include "common.h" + +void OPT_sample_repetition_penalty(OPT_token_data_array* candidates, const int* last_tokens, size_t last_tokens_size, + float penalty) { if (last_tokens_size == 0 || penalty == 1.0f) { return; } @@ -14,8 +16,7 @@ void OPT_sample_repetition_penalty(OPT_token_data_array * candidates, const int if (candidates->data[i].logit <= 0) { candidates->data[i].logit *= penalty; - } - else { + } else { candidates->data[i].logit /= penalty; } } @@ -23,8 +24,8 @@ void OPT_sample_repetition_penalty(OPT_token_data_array * candidates, const int candidates->sorted = false; } -void OPT_sample_frequency_and_presence_penalties(OPT_token_data_array * candidates, const int * last_tokens_p, - size_t last_tokens_size, float alpha_frequency, float alpha_presence) { +void OPT_sample_frequency_and_presence_penalties(OPT_token_data_array* candidates, const int* last_tokens_p, + size_t last_tokens_size, float alpha_frequency, float alpha_presence) { if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { return; } @@ -49,34 +50,32 @@ void OPT_sample_frequency_and_presence_penalties(OPT_token_data_array * candidat candidates->sorted = false; } -int OPT_sample_token_greedy(OPT_token_data_array * candidates) { +int OPT_sample_token_greedy(OPT_token_data_array* candidates) { // Find max element - auto max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const OPT_token_data & a, const OPT_token_data & b) { - return a.logit < b.logit; - }); + auto max_iter = + std::max_element(candidates->data, candidates->data + candidates->size, + [](const OPT_token_data& a, const OPT_token_data& b) { return a.logit < b.logit; }); int result = max_iter->id; return result; } -void OPT_sample_temperature(OPT_token_data_array * candidates_p, float temp) { +void OPT_sample_temperature(OPT_token_data_array* candidates_p, float temp) { for (size_t i = 0; i < candidates_p->size; ++i) { candidates_p->data[i].logit /= temp; } } - // // sampling // -void OPT_sample_softmax(OPT_token_data_array * candidates) { +void OPT_sample_softmax(OPT_token_data_array* candidates) { assert(candidates->size > 0); // Sort the logits in descending order if (!candidates->sorted) { - std::sort(candidates->data, candidates->data + candidates->size, [](const OPT_token_data & a, const OPT_token_data & b) { - return a.logit > b.logit; - }); + std::sort(candidates->data, candidates->data + candidates->size, + [](const OPT_token_data& a, const OPT_token_data& b) { return a.logit > b.logit; }); candidates->sorted = true; } @@ -92,7 +91,7 @@ void OPT_sample_softmax(OPT_token_data_array * candidates) { } } -int OPT_sample_token(OPT_token_data_array * candidates) { +int OPT_sample_token(OPT_token_data_array* candidates) { OPT_sample_softmax(candidates); std::vector probs; @@ -102,26 +101,23 @@ int OPT_sample_token(OPT_token_data_array * candidates) { } std::discrete_distribution<> dist(probs.begin(), probs.end()); - auto & rng = OPT_rng; + auto& rng = OPT_rng; int idx = dist(rng); int result = candidates->data[idx].id; return result; } -void OPT_sample_top_k(OPT_token_data_array * candidates, int k, size_t min_keep) { - k = std::max(k, (int) min_keep); - k = std::min(k, (int) candidates->size); +void OPT_sample_top_k(OPT_token_data_array* candidates, int k, size_t min_keep) { + k = std::max(k, (int)min_keep); + k = std::min(k, (int)candidates->size); // Sort scores in descending order if (!candidates->sorted) { - auto comp = [](const OPT_token_data & a, const OPT_token_data & b) { - return a.logit > b.logit; - }; - if (k == (int) candidates->size) { + auto comp = [](const OPT_token_data& a, const OPT_token_data& b) { return a.logit > b.logit; }; + if (k == (int)candidates->size) { std::sort(candidates->data, candidates->data + candidates->size, comp); - } - else { + } else { std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); } candidates->sorted = true; @@ -130,7 +126,8 @@ void OPT_sample_top_k(OPT_token_data_array * candidates, int k, size_t min_keep) candidates->size = k; } -int OPT_sample_token_mirostat(const int n_vocab, OPT_token_data_array * candidates, float tau, float eta, int m, float * mu) { +int OPT_sample_token_mirostat(const int n_vocab, OPT_token_data_array* candidates, float tau, float eta, int m, + float* mu) { auto N = float(n_vocab); OPT_sample_softmax(candidates); @@ -156,9 +153,9 @@ int OPT_sample_token_mirostat(const int n_vocab, OPT_token_data_array * candidat int X = OPT_sample_token(candidates); // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const OPT_token_data & candidate) { - return candidate.id == X; - })); + size_t X_idx = std::distance(candidates->data, + std::find_if(candidates->data, candidates->data + candidates->size, + [&](const OPT_token_data& candidate) { return candidate.id == X; })); float observed_surprise = -log2f(candidates->data[X_idx].p); float e = observed_surprise - tau; @@ -168,13 +165,13 @@ int OPT_sample_token_mirostat(const int n_vocab, OPT_token_data_array * candidat return X; } -int OPT_sample_token_mirostat_v2(OPT_token_data_array * candidates, float tau, float eta, float * mu) { +int OPT_sample_token_mirostat_v2(OPT_token_data_array* candidates, float tau, float eta, float* mu) { OPT_sample_softmax(candidates); // Truncate the words with surprise values greater than mu - candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const OPT_token_data & candidate) { - return -log2f(candidate.p) > *mu; - })); + candidates->size = std::distance( + candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, + [&](const OPT_token_data& candidate) { return -log2f(candidate.p) > *mu; })); // Normalize the probabilities of the remaining words OPT_sample_softmax(candidates); @@ -183,9 +180,9 @@ int OPT_sample_token_mirostat_v2(OPT_token_data_array * candidates, float tau, f int X = OPT_sample_token(candidates); // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const OPT_token_data & candidate) { - return candidate.id == X; - })); + size_t X_idx = std::distance(candidates->data, + std::find_if(candidates->data, candidates->data + candidates->size, + [&](const OPT_token_data& candidate) { return candidate.id == X; })); float observed_surprise = -log2f(candidates->data[X_idx].p); float e = observed_surprise - tau; @@ -195,7 +192,7 @@ int OPT_sample_token_mirostat_v2(OPT_token_data_array * candidates, float tau, f return X; } -void OPT_sample_tail_free(OPT_token_data_array * candidates, float z, size_t min_keep) { +void OPT_sample_tail_free(OPT_token_data_array* candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; } @@ -220,7 +217,7 @@ void OPT_sample_tail_free(OPT_token_data_array * candidates, float z, size_t min // Normalize the second derivatives float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); - for (float & value : second_derivatives) { + for (float& value : second_derivatives) { value /= second_derivatives_sum; } @@ -240,7 +237,7 @@ void OPT_sample_tail_free(OPT_token_data_array * candidates, float z, size_t min candidates->size = last_idx; } -void OPT_sample_typical(OPT_token_data_array * candidates, float p, size_t min_keep) { +void OPT_sample_typical(OPT_token_data_array* candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -266,9 +263,8 @@ void OPT_sample_typical(OPT_token_data_array * candidates, float p, size_t min_k std::vector indices(candidates->size); std::iota(indices.begin(), indices.end(), 0); - std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { - return shifted_scores[a] < shifted_scores[b]; - }); + std::sort(indices.begin(), indices.end(), + [&](size_t a, size_t b) { return shifted_scores[a] < shifted_scores[b]; }); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -297,7 +293,7 @@ void OPT_sample_typical(OPT_token_data_array * candidates, float p, size_t min_k candidates->size = new_candidates.size(); } -void OPT_sample_top_p(OPT_token_data_array * candidates, float p, size_t min_keep) { +void OPT_sample_top_p(OPT_token_data_array* candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } @@ -323,51 +319,65 @@ void OPT_sample_top_p(OPT_token_data_array * candidates, float p, size_t min_kee } // OPTGenerate function -std::vector OPTGenerate(std::vector input_ids, - const struct opt_params generation_config) { +std::vector OPTGenerate(OPTForCausalLM model, std::vector input_ids, + const struct opt_params generation_config) { std::vector last_n_tokens(generation_config.n_ctx); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); std::vector embd; std::vector generate_ids; int n_consumed = 0; - while ((int) input_ids.size() > n_consumed) { + while ((int)input_ids.size() > n_consumed) { embd.push_back(input_ids[n_consumed]); last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(input_ids[n_consumed]); ++n_consumed; - if ((int) embd.size() >= generation_config.n_batch) { + if ((int)embd.size() >= generation_config.n_batch) { break; } } + bool has_past_kv = false; + std::vector> past_keys, past_values; int n_remain = generation_config.n_predict; while (n_remain != 0) { - // Predict and evaluate tokens - // TODO: Enable OPT calculation here - // "input_ids" is the vector of all input token ids. - // "embd" is the vector of all context token ids (including generated token ids). - // "generate_ids" is the vector of all generated token ids (input_ids are not included). - //std::vector logits = xxxx - std::vector logits; // TODO: Remove this line after enabling OPT calculation - logits.reserve(generation_config.n_vocab); // TODO: Remove this line after enabling OPT calculation + std::vector logits(generation_config.n_vocab); + + int sqlen = 1; + struct OPTForCausalLM_output model_output; + if (has_past_kv) { + Matrix3D input_ids_mat(input_ids.data(), 1, 1, sqlen); + struct OPTForCausalLM_input model_input = {input_ids_mat, past_keys, past_values}; + model_output = model.forward(model_input); + } else { + sqlen = input_ids.size(); + Matrix3D input_ids_mat(input_ids.data(), 1, 1, sqlen); + struct OPTForCausalLM_input model_input = {input_ids_mat}; + model_output = model.forward(model_input); + } + past_keys = model_output.past_keys; + past_values = model_output.past_values; + has_past_kv = true; + // memcpy model_ouput.logits[-1] to logits + memcpy(logits.data(), &model_output.logits.m_data[(sqlen - 1) * generation_config.n_vocab], + generation_config.n_vocab * sizeof(float)); // Generate - const int n_ctx = generation_config.n_ctx; - const float temp = generation_config.temp; - const int32_t top_k = generation_config.top_k <= 0 ? generation_config.n_vocab : generation_config.top_k; - const float top_p = generation_config.top_p; - const float tfs_z = generation_config.tfs_z; - const float typical_p = generation_config.typical_p; - const int32_t repeat_last_n = generation_config.repeat_last_n < 0 ? n_ctx : generation_config.repeat_last_n; - const float repeat_penalty = generation_config.repeat_penalty; - const float alpha_presence = generation_config.presence_penalty; - const float alpha_frequency = generation_config.frequency_penalty; - const int mirostat = generation_config.mirostat; - const float mirostat_tau = generation_config.mirostat_tau; - const float mirostat_eta = generation_config.mirostat_eta; - const int n_vocab = generation_config.n_vocab; + const int n_ctx = generation_config.n_ctx; + const float temp = generation_config.temp; + const int32_t top_k = generation_config.top_k <= 0 ? generation_config.n_vocab : generation_config.top_k; + const float top_p = generation_config.top_p; + const float tfs_z = generation_config.tfs_z; + const float typical_p = generation_config.typical_p; + const int32_t repeat_last_n = generation_config.repeat_last_n < 0 ? n_ctx : generation_config.repeat_last_n; + const float repeat_penalty = generation_config.repeat_penalty; + const float alpha_presence = generation_config.presence_penalty; + const float alpha_frequency = generation_config.frequency_penalty; + const int mirostat = generation_config.mirostat; + const float mirostat_tau = generation_config.mirostat_tau; + const float mirostat_eta = generation_config.mirostat_eta; + const int n_vocab = generation_config.n_vocab; // Apply generation_config.logit_bias map /* // TODO: Enable logit_bias here @@ -382,34 +392,31 @@ std::vector OPTGenerate(std::vector input_ids, candidates.emplace_back(OPT_token_data{token_id, logits[token_id], 0.0f}); } - OPT_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + OPT_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; // Apply penalties auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); - OPT_sample_repetition_penalty(&candidates_p, - last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - last_n_repeat, repeat_penalty); + OPT_sample_repetition_penalty(&candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + last_n_repeat, repeat_penalty); OPT_sample_frequency_and_presence_penalties(&candidates_p, - last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - last_n_repeat, alpha_frequency, alpha_presence); + last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + last_n_repeat, alpha_frequency, alpha_presence); int id = 0; if (temp <= 0) { id = OPT_sample_token_greedy(&candidates_p); - } - else { + } else { if (mirostat == 1) { static float mirostat_mu = 2.0f * mirostat_tau; const int mirostat_m = 100; OPT_sample_temperature(&candidates_p, temp); - id = OPT_sample_token_mirostat(n_vocab, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); - } - else if (mirostat == 2) { + id = OPT_sample_token_mirostat(n_vocab, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, + &mirostat_mu); + } else if (mirostat == 2) { static float mirostat_mu = 2.0f * mirostat_tau; OPT_sample_temperature(&candidates_p, temp); id = OPT_sample_token_mirostat_v2(&candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); - } - else { + } else { // Temperature sampling OPT_sample_top_k(&candidates_p, top_k, 1); OPT_sample_tail_free(&candidates_p, tfs_z, 1); @@ -424,6 +431,7 @@ std::vector OPTGenerate(std::vector input_ids, last_n_tokens.push_back(id); embd.push_back(id); generate_ids.push_back(id); + input_ids = std::vector{id}; --n_remain; } diff --git a/experimental/transformer/tests/test_OPTGenerate.cc b/experimental/transformer/tests/test_OPTGenerate.cc index 21d75458..b691f3ae 100644 --- a/experimental/transformer/tests/test_OPTGenerate.cc +++ b/experimental/transformer/tests/test_OPTGenerate.cc @@ -1,39 +1,23 @@ #include -#include "OPTGenerate.h" -void test_OPTGenerate() { - // std::cout << "Test End!" << std::endl; -} +#include "OPTGenerate.h" int main() { - // test_OPTGenerate(); - std::vector input_ids = {37500, 10, 998, 64, 28, 626, 11, 158, 2007, 2402, 4, 152, 1579, 16, - 13, 937, 82, 6, 98, 52, 6876, 51, 218, 75, 33, 3280, 14198, 4}; + // std::vector input_ids = {37500, 10, 998, 64, 28, 626, 11, 158, 2007, 2402, 4, 152, 1579, 16, + // 13, 937, 82, 6, 98, 52, 6876, 51, 218, 75, 33, 3280, 14198, 4}; + std::string vocab_file = "./models/OPT_125m/vocab.json"; + std::string bpe_file = "./models/OPT_125m/merges.txt"; + + Encoder encoder = get_encoder(vocab_file, bpe_file); + std::vector input_ids = encoder.encode("John went to MIT and study Computer Science."); + + std::string decoded = encoder.decode(input_ids); + std::cout << "input:" << decoded << std::endl; + + OPTForCausalLM model = OPTForCausalLM("models/OPT_125m", get_opt_model_config(OPT_125M)); const struct opt_params generation_config; - std::vector generated_ids = OPTGenerate(input_ids, generation_config); - //std::vector generated_ids_answer = {} // TODO: add answer - - /* - std::cout << "Generated: "; - for (int i = 0; i < generated_ids.size(); i++) { - std::cout << generated_ids[i] << " "; - } - std::cout << std::endl; - */ - - /* - // TODO: add comparison after adding answer - bool is_equal = true; - for (int i = 0; i < generated_ids.size(); i++) { - if (generated_ids[i] != generated_ids_answer[i]) { - is_equal = false; - break; - } - } - if (is_equal) - std::cout << "-------- Test of OPTGenerate: Passed! -------- " << std::endl; - else - std::cout << "-------- Test of OPTGenerate: Failed! -------- " << std::endl; - */ - std::cout << "-------- End of test_OPTGenerate --------" << std::endl; + std::vector generated_ids = OPTGenerate(model, input_ids, generation_config); + + decoded = encoder.decode(generated_ids); + std::cout << "generated:" << decoded << std::endl; };