From c92a22df8a2b9922cc59f440da63874d3b5d8b30 Mon Sep 17 00:00:00 2001 From: meenchen Date: Tue, 23 May 2023 11:28:40 +0800 Subject: [PATCH] fix tests on macos --- experimental/transformer/src/OPTTokenizer.cc | 62 ++++---- experimental/transformer/src/utils.cc | 12 +- .../tests/test_Int8OPTDecoderLayer.cc | 136 ++---------------- 3 files changed, 50 insertions(+), 160 deletions(-) diff --git a/experimental/transformer/src/OPTTokenizer.cc b/experimental/transformer/src/OPTTokenizer.cc index a717d4ee..d98998a8 100644 --- a/experimental/transformer/src/OPTTokenizer.cc +++ b/experimental/transformer/src/OPTTokenizer.cc @@ -2,7 +2,7 @@ /*std::vector OPT_tokenize(const OPT_vocab & vocab, const std::string & text, bool add_bos) { std::vector res(text.size() + (int) add_bos); - return res; + return res; }*/ /* @@ -44,21 +44,20 @@ std::map PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { }; */ - -/* - * Tokenizer +/* + * Tokenizer */ Encoder::Encoder(std::map encoder, std::vector> bpe_merges) { this->encoder = encoder; - for(auto &it: encoder) { + for (auto &it : encoder) { this->decoder[it.second] = it.first; } this->byte_encoder = bytes_to_unicode(); - for(auto &it: byte_encoder) { + for (auto &it : byte_encoder) { this->byte_decoder[it.second] = it.first; } - for(int i = 0; i < bpe_merges.size(); ++i) { - this->bpe_ranks.insert(std::make_pair(bpe_merges[i], i)); + for (int i = 0; i < bpe_merges.size(); ++i) { + this->bpe_ranks.insert(std::make_pair(bpe_merges[i], i)); } } @@ -93,7 +92,7 @@ Encoder::std::vector> bytes_to_unicode() { std::unordered_map Encoder::bytes_to_unicode() { std::unordered_map byte_to_unicode; - + // Range from '!' to '~' for (int b = '!'; b <= '~'; ++b) { byte_to_unicode[b] = std::string(1, static_cast(b)); @@ -143,7 +142,7 @@ std::unordered_map Encoder::bytes_to_unicode() { byte_to_unicode[0xA3] = u8"\u0156"; // Ŗ byte_to_unicode[0xA4] = u8"\u00A4"; // Currency symbol byte_to_unicode[0xA5] = u8"\u0128"; // Ĩ - + return byte_to_unicode; } @@ -159,43 +158,42 @@ std::set> Encoder::get_pairs(std::vectorcache.find(token); // Find the token in the cache - if (it != this->cache.end()) { // If the token is in the cache + auto it = this->cache.find(token); // Find the token in the cache + if (it != this->cache.end()) { // If the token is in the cache return it->second; } - std::vector word; // word = tuple(token) + std::vector word; // word = tuple(token) for (char c : token) { word.push_back(std::string(1, c)); } std::set> pairs = get_pairs(word); - if(pairs.empty()) - return token; + if (pairs.empty()) return token; - while(true) { + while (true) { std::pair bigram; - int min_index = std::numeric_limits::max(); // Start with the highest possible int value + int min_index = std::numeric_limits::max(); // Start with the highest possible int value - for (const auto &pair: pairs) { - auto it = this->bpe_ranks.find(pair); // Find the pair in the map - if (it != this->bpe_ranks.end()) { // If the pair is in the map - if (it->second < min_index) { // If the current pair's value is less than the min_index + for (const auto &pair : pairs) { + auto it = this->bpe_ranks.find(pair); // Find the pair in the map + if (it != this->bpe_ranks.end()) { // If the pair is in the map + if (it->second < min_index) { // If the current pair's value is less than the min_index min_index = it->second; bigram = pair; } } } - if (min_index == std::numeric_limits::max()) // No pair was found in bpe_ranks + if (min_index == std::numeric_limits::max()) // No pair was found in bpe_ranks break; std::string first = bigram.first; std::string second = bigram.second; std::vector new_word; int i = 0; - while(i < word.size()) { + while (i < word.size()) { auto it = std::find(word.begin() + i, word.end(), first); if (it == word.end()) { new_word.insert(new_word.end(), word.begin() + i, word.end()); @@ -208,8 +206,7 @@ std::string Encoder::bpe(std::string token) { if (word[i] == first && i < word.size() - 1 && word[i + 1] == second) { new_word.push_back(first + second); i += 2; - } - else { + } else { new_word.push_back(word[i]); i += 1; } @@ -232,7 +229,11 @@ std::vector Encoder::encode(std::string text) { std::vector bpe_tokens; // Using Regex to tokenize - std::regex pat = std::regex("'s|'t|'re|'ve|'m|'ll|'d| ?[a-zA-Z]+| ?[0-9]+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"); + // MACOS does not support p{L}\\p{N}, we may need different regex lib + // std::regex pat = std::regex("'s|'t|'re|'ve|'m|'ll|'d| ?[a-zA-Z]+| ?[0-9]+| + // ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"); + std::regex pat = std::regex("'s|'t|'re|'ve|'m|'ll|'d| ?[a-zA-Z]+| ?[0-9]+| ?[^\\s]+|\\s+(?!\\S)|\\s+"); + std::sregex_iterator iter(text.begin(), text.end(), pat); std::sregex_iterator end; @@ -241,7 +242,7 @@ std::vector Encoder::encode(std::string text) { std::string encoded_token; for (char b : token) { - for (auto &it: this->byte_encoder) { + for (auto &it : this->byte_encoder) { if (it.first == int(static_cast(b))) { encoded_token += it.second; break; @@ -265,8 +266,7 @@ std::string Encoder::decode(std::vector tokens) { if (int(this->decoder[token][0]) < '!' || int(this->decoder[token][0]) > '~') { text += " "; i_flag = 2; - } - else { + } else { text += std::string(1, this->decoder[token][0]); } @@ -296,7 +296,9 @@ Encoder get_encoder(std::string vocab_file, std::string bpe_file) { while (std::getline(infile, line)) { std::istringstream iss(line); std::string a, b; - if (!(iss >> a >> b)) { break; } // error + if (!(iss >> a >> b)) { + break; + } // error bpe_merges.push_back({a, b}); } diff --git a/experimental/transformer/src/utils.cc b/experimental/transformer/src/utils.cc index e5e1ca1d..9c04f8ce 100644 --- a/experimental/transformer/src/utils.cc +++ b/experimental/transformer/src/utils.cc @@ -3,16 +3,22 @@ #include #include +#include // for errno #include #include +#include // for strerror #include template void read_to_array(const char* path, T* array, int size) { std::ifstream infile(path, std::ios::binary | std::ios::in); - assert(infile); - infile.read(reinterpret_cast(array), size * sizeof(T)); - infile.close(); + if (infile.fail()) { + std::cout << "Failed to open file: " << strerror(errno) << std::endl; + throw("Expected error..."); + } else { + infile.read(reinterpret_cast(array), size * sizeof(T)); + infile.close(); + } } struct max_error_info { diff --git a/experimental/transformer/tests/test_Int8OPTDecoderLayer.cc b/experimental/transformer/tests/test_Int8OPTDecoderLayer.cc index 9e3ce98a..2b5720c6 100644 --- a/experimental/transformer/tests/test_Int8OPTDecoderLayer.cc +++ b/experimental/transformer/tests/test_Int8OPTDecoderLayer.cc @@ -1,8 +1,8 @@ -#include "Int8OPTDecoderLayer.h" +#include "Int8OPTDecoder.h" #include "operators.h" #include "utils.h" -#define MAX_TEST_MEMORY_BUF 1024 * 1024 * 1024 // 1 GB +#define MAX_TEST_MEMORY_BUF 1 * 1024 * 1024 * 1576 // 1.5 GB static char buffer[MAX_TEST_MEMORY_BUF]; class MemoryAllocator { @@ -232,66 +232,7 @@ void test_DecoderLayer_generate_cache() { const int sqlen = 1, b = 1, past_len = 108, head_dim = embed_dim / num_heads; MemoryAllocator mem_buf; - struct BMM_S8T_S8N_F32T_params qk_bmm; - struct BMM_S8T_S8N_S8T_params pv_bmm; - struct W8A8B8O8Linear_params k_proj, v_proj, q_proj; - Matrix3D k_proj_weight(mem_buf.get_int8buffer(embed_dim * embed_dim), 1, embed_dim, embed_dim); - Matrix3D k_proj_bias(mem_buf.get_int8buffer(embed_dim), 1, 1, embed_dim); - k_proj.weight = k_proj_weight; - k_proj.bias = k_proj_bias; - auto k_proj_op = W8A8B8O8Linear(k_proj); - Matrix3D v_proj_weight(mem_buf.get_int8buffer(embed_dim * embed_dim), 1, embed_dim, embed_dim); - Matrix3D v_proj_bias(mem_buf.get_int8buffer(embed_dim), 1, 1, embed_dim); - v_proj.weight = v_proj_weight; - v_proj.bias = v_proj_bias; - auto v_proj_op = W8A8B8O8Linear(v_proj); - Matrix3D q_proj_weight(mem_buf.get_int8buffer(embed_dim * embed_dim), 1, embed_dim, embed_dim); - Matrix3D q_proj_bias(mem_buf.get_int8buffer(embed_dim), 1, 1, embed_dim); - q_proj.weight = q_proj_weight; - q_proj.bias = q_proj_bias; - auto q_proj_op = W8A8B8O8Linear(q_proj); - - struct W8A8BFP32OFP32Linear_params out_proj; - Matrix3D out_proj_weight(mem_buf.get_int8buffer(embed_dim * embed_dim), 1, embed_dim, embed_dim); - Matrix3D out_proj_bias(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - out_proj.weight = out_proj_weight; - out_proj.bias = out_proj_bias; - auto out_proj_op = W8A8BFP32OFP32Linear(out_proj); - - struct LayerNormQ_params self_attn_layer_norm, final_layer_norm; - Matrix3D self_attn_layer_norm_weight(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - Matrix3D self_attn_layer_norm_bias(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - self_attn_layer_norm.weight = self_attn_layer_norm_weight; - self_attn_layer_norm.bias = self_attn_layer_norm_bias; - - Matrix3D final_layer_norm_weight(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - Matrix3D final_layer_norm_bias(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - final_layer_norm.weight = final_layer_norm_weight; - final_layer_norm.bias = final_layer_norm_bias; - LayerNormQ self_attn_layer_norm_op = LayerNormQ(self_attn_layer_norm); - LayerNormQ final_layer_norm_op = LayerNormQ(final_layer_norm); - - struct W8A8B8O8LinearReLU_params fc1; - Matrix3D fc1_weight(mem_buf.get_int8buffer(embed_dim * hidden_dim), 1, hidden_dim, embed_dim); - Matrix3D fc1_bias(mem_buf.get_int8buffer(hidden_dim), 1, 1, hidden_dim); - fc1.weight = fc1_weight; - fc1.bias_int8 = fc1_bias; - auto fc1_op = W8A8B8O8LinearReLU(fc1); - - struct W8A8BFP32OFP32Linear_params fc2; - Matrix3D fc2_weight(mem_buf.get_int8buffer(embed_dim * hidden_dim), 1, embed_dim, hidden_dim); - Matrix3D fc2_bias(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - fc2.weight = fc2_weight; - fc2.bias = fc2_bias; - auto fc2_op = W8A8BFP32OFP32Linear(fc2); - - auto qk_bmm_op = BMM_S8T_S8N_F32T(qk_bmm); - auto pv_bmm_op = BMM_S8T_S8N_S8T(pv_bmm); - - int layer_idx = 0; - Int8OPTDecoderLayer layer = Int8OPTDecoderLayer( - "models/OPT_125m/decoder/layer0", get_opt_model_config(OPT_125M), layer_idx, self_attn_layer_norm_op, - final_layer_norm_op, fc1_op, fc2_op, qk_bmm_op, pv_bmm_op, k_proj_op, v_proj_op, q_proj_op, out_proj_op); + Int8OPTDecoder decoder = Int8OPTDecoder("models/OPT_125m/decoder/", get_opt_model_config(OPT_125M)); int tgz = sqlen + past_len; Matrix3D hidden_states(mem_buf.get_fpbuffer(b * sqlen * embed_dim), b, sqlen, embed_dim); @@ -305,12 +246,12 @@ void test_DecoderLayer_generate_cache() { struct Int8OPTDecoderLayer_input input = {hidden_states, attention_mask, past_keys, past_value}; - struct Int8OPTDecoderLayer_output output = layer.forward(input); + struct Int8OPTDecoderLayer_output output = decoder.layers[0].forward(input); Matrix3D residualGT(mem_buf.get_fpbuffer(b * sqlen * embed_dim), b, sqlen, embed_dim); read_to_array("assets/tests/OPT_125m/test_cache_residual.bin", residualGT.m_data, b * sqlen * embed_dim); - // print_first_k_elelment("output.hidden_states.m_data", output.hidden_states.m_data, 64); - // print_first_k_elelment("residualGT.m_data", residualGT.m_data, 64); + // // print_first_k_elelment("output.hidden_states.m_data", output.hidden_states.m_data, 64); + // // print_first_k_elelment("residualGT.m_data", residualGT.m_data, 64); int8_t *key_statesGT = mem_buf.get_int8buffer(output.past_key_value.first.length()); read_to_array("assets/tests/OPT_125m/test_present_key.bin", key_statesGT, output.past_key_value.first.length()); int8_t *value_statesGT = mem_buf.get_int8buffer(output.past_key_value.second.length()); @@ -334,66 +275,7 @@ void test_DecoderLayer_generate_cache_1_3B() { const int sqlen = 1, b = 1, past_len = 108, head_dim = embed_dim / num_heads; MemoryAllocator mem_buf; - struct BMM_S8T_S8N_F32T_params qk_bmm; - struct BMM_S8T_S8N_S8T_params pv_bmm; - struct W8A8B8O8Linear_params k_proj, v_proj, q_proj; - Matrix3D k_proj_weight(mem_buf.get_int8buffer(embed_dim * embed_dim), 1, embed_dim, embed_dim); - Matrix3D k_proj_bias(mem_buf.get_int8buffer(embed_dim), 1, 1, embed_dim); - k_proj.weight = k_proj_weight; - k_proj.bias = k_proj_bias; - auto k_proj_op = W8A8B8O8Linear(k_proj); - Matrix3D v_proj_weight(mem_buf.get_int8buffer(embed_dim * embed_dim), 1, embed_dim, embed_dim); - Matrix3D v_proj_bias(mem_buf.get_int8buffer(embed_dim), 1, 1, embed_dim); - v_proj.weight = v_proj_weight; - v_proj.bias = v_proj_bias; - auto v_proj_op = W8A8B8O8Linear(v_proj); - Matrix3D q_proj_weight(mem_buf.get_int8buffer(embed_dim * embed_dim), 1, embed_dim, embed_dim); - Matrix3D q_proj_bias(mem_buf.get_int8buffer(embed_dim), 1, 1, embed_dim); - q_proj.weight = q_proj_weight; - q_proj.bias = q_proj_bias; - auto q_proj_op = W8A8B8O8Linear(q_proj); - - struct W8A8BFP32OFP32Linear_params out_proj; - Matrix3D out_proj_weight(mem_buf.get_int8buffer(embed_dim * embed_dim), 1, embed_dim, embed_dim); - Matrix3D out_proj_bias(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - out_proj.weight = out_proj_weight; - out_proj.bias = out_proj_bias; - auto out_proj_op = W8A8BFP32OFP32Linear(out_proj); - - struct LayerNormQ_params self_attn_layer_norm, final_layer_norm; - Matrix3D self_attn_layer_norm_weight(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - Matrix3D self_attn_layer_norm_bias(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - self_attn_layer_norm.weight = self_attn_layer_norm_weight; - self_attn_layer_norm.bias = self_attn_layer_norm_bias; - - Matrix3D final_layer_norm_weight(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - Matrix3D final_layer_norm_bias(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - final_layer_norm.weight = final_layer_norm_weight; - final_layer_norm.bias = final_layer_norm_bias; - LayerNormQ self_attn_layer_norm_op = LayerNormQ(self_attn_layer_norm); - LayerNormQ final_layer_norm_op = LayerNormQ(final_layer_norm); - - struct W8A8B8O8LinearReLU_params fc1; - Matrix3D fc1_weight(mem_buf.get_int8buffer(embed_dim * hidden_dim), 1, hidden_dim, embed_dim); - Matrix3D fc1_bias(mem_buf.get_int8buffer(hidden_dim), 1, 1, hidden_dim); - fc1.weight = fc1_weight; - fc1.bias_int8 = fc1_bias; - auto fc1_op = W8A8B8O8LinearReLU(fc1); - - struct W8A8BFP32OFP32Linear_params fc2; - Matrix3D fc2_weight(mem_buf.get_int8buffer(embed_dim * hidden_dim), 1, embed_dim, hidden_dim); - Matrix3D fc2_bias(mem_buf.get_fpbuffer(embed_dim), 1, 1, embed_dim); - fc2.weight = fc2_weight; - fc2.bias = fc2_bias; - auto fc2_op = W8A8BFP32OFP32Linear(fc2); - - auto qk_bmm_op = BMM_S8T_S8N_F32T(qk_bmm); - auto pv_bmm_op = BMM_S8T_S8N_S8T(pv_bmm); - - int layer_idx = 0; - Int8OPTDecoderLayer layer = Int8OPTDecoderLayer( - "models/OPT_1.3B/decoder/layer0", get_opt_model_config(OPT_1_3B), layer_idx, self_attn_layer_norm_op, - final_layer_norm_op, fc1_op, fc2_op, qk_bmm_op, pv_bmm_op, k_proj_op, v_proj_op, q_proj_op, out_proj_op); + Int8OPTDecoder decoder = Int8OPTDecoder("models/OPT_1.3B/decoder/", get_opt_model_config(OPT_1_3B)); int tgz = sqlen + past_len; Matrix3D hidden_states(mem_buf.get_fpbuffer(b * sqlen * embed_dim), b, sqlen, embed_dim); @@ -407,7 +289,7 @@ void test_DecoderLayer_generate_cache_1_3B() { struct Int8OPTDecoderLayer_input input = {hidden_states, attention_mask, past_keys, past_value}; - struct Int8OPTDecoderLayer_output output = layer.forward(input); + struct Int8OPTDecoderLayer_output output = decoder.layers[0].forward(input); Matrix3D residualGT(mem_buf.get_fpbuffer(b * sqlen * embed_dim), b, sqlen, embed_dim); read_to_array("assets/tests/OPT_1.3B/test_cache_residual.bin", residualGT.m_data, b * sqlen * embed_dim); @@ -526,9 +408,9 @@ void test_DecoderLayer() { } int main() { + test_DecoderLayer(); test_DecoderLayer_generate(); test_DecoderLayer_generate_1_3B(); test_DecoderLayer_generate_cache(); test_DecoderLayer_generate_cache_1_3B(); - test_DecoderLayer(); }