From 0c9d32589f068932019f508fe64f1ddc0e166ac0 Mon Sep 17 00:00:00 2001
From: meenchen
Date: Tue, 23 May 2023 19:52:02 +0800
Subject: [PATCH] demo

---
Notes (free-form text between the "---" marker and the diffstat; not part of the commit):

What this patch does
 * Makefile: adds a `demo` target built from application/demo.cc and drops
   `-D PROFILER` from the test_OPTTokenizer and test_OPTGenerate recipes.
 * application/demo.cc (new): selects a model from argv[1] (default "OPT1.3B"),
   loads it, reads one line from stdin, encodes it and calls OPTGenerate with
   n_predict = 256 in interactive mode.
 * OPTGenerate.h / OPTGenerate.cc: removes the OPTGenerate_interactive
   declaration and adds two defaulted parameters to OPTGenerate
   (`Encoder* encoder = NULL`, `bool interactive = false`). Interactive output
   is forced off when `encoder` is NULL; otherwise each sampled token is decoded
   and printed as soon as it is produced. Each loop iteration is timed under
   the "Token generation" section and the profiler report is printed and reset
   before returning.
 * profiler.h: the printing body of report() becomes report_internal(), which
   is always compiled; report() still only prints when PROFILER is defined.
   Rows without a flops entry now end in ", N/A, N/A".
 * utils.h: adds STATS_START / STATS_FLOPS / STATS_END, defined regardless of
   PROFILER.

Review fixes applied to the added lines of demo.cc (line counts unchanged)
 * The model-name check tested `target_model` (still the default at that
   point) instead of the user-supplied `target_str`, so an unknown name was
   never rejected and later default-inserted into both lookup maps.
 * `throw("Unsupported model\n")` threw a string literal with no handler,
   which ends in std::terminate; main now returns 1 after printing the
   diagnostic.

NOTE(review)
 * This copy of the patch had lost its line breaks and every `<...>` span.
   Include targets and template arguments were restored from usage;
   `std::vector<Matrix3D<int8_t>>` (past_keys/past_values) and
   `std::vector<float>` (logits) are best guesses - confirm against the tree.
   Context-line indentation is likewise reconstructed.
 * src/OPTGenerate.cc includes "util.h" while the STATS_* macros are added to
   include/utils.h - confirm the header name.

 experimental/transformer/Makefile             |  9 ++-
 experimental/transformer/application/demo.cc  | 63 +++++++++++++++++++
 .../transformer/include/OPTGenerate.h         |  6 +-
 experimental/transformer/include/profiler.h   | 10 ++-
 experimental/transformer/include/utils.h      |  4 ++
 experimental/transformer/src/OPTGenerate.cc   | 15 ++++-
 6 files changed, 96 insertions(+), 11 deletions(-)
 create mode 100644 experimental/transformer/application/demo.cc

diff --git a/experimental/transformer/Makefile b/experimental/transformer/Makefile
index fd31c424..007d336a 100644
--- a/experimental/transformer/Makefile
+++ b/experimental/transformer/Makefile
@@ -3,7 +3,7 @@ CXX = g++
 CXXFLAGS = -std=c++17 -mavx2 -pthread -O3
 
 # Executable and source files
-TARGET = test_ops test_Int8OPTAttention test_Int8OPTDecoderLayer test_Int8OPTDecoder test_OPTForCausalLM profile_OPTForCausalLM test_OPTTokenizer test_OPTGenerate
+TARGET = test_ops test_Int8OPTAttention test_Int8OPTDecoderLayer test_Int8OPTDecoder test_OPTForCausalLM profile_OPTForCausalLM test_OPTTokenizer test_OPTGenerate demo
 
 LIB_DIR = ../matmul_optimization/src
 LIB_SRC = $(wildcard $(LIB_DIR)/lib/*.cc)
@@ -40,10 +40,13 @@ profile_OPTForCausalLM:
 	$(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -D PROFILER -o profile_OPTForCausalLM tests/test_OPTForCausalLM.cc $(SRC)
 
 test_OPTTokenizer:
-	$(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -D PROFILER -o test_OPTTokenizer tests/test_OPTTokenizer.cc $(SRC)
+	$(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -o test_OPTTokenizer tests/test_OPTTokenizer.cc $(SRC)
 
 test_OPTGenerate:
-	$(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -D PROFILER -o test_OPTGenerate tests/test_OPTGenerate.cc $(SRC)
+	$(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -o test_OPTGenerate tests/test_OPTGenerate.cc $(SRC)
+
+demo:
+	$(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -o demo application/demo.cc $(SRC)
 
 # Clean up
 clean:
diff --git a/experimental/transformer/application/demo.cc b/experimental/transformer/application/demo.cc
new file mode 100644
index 00000000..48684f44
--- /dev/null
+++ b/experimental/transformer/application/demo.cc
@@ -0,0 +1,63 @@
+#include <iostream>
+#include <map>
+
+#include "OPTGenerate.h"
+
+std::map<std::string, int> model_config = {
+    {"OPT125M", OPT_125M},
+    {"OPT1.3B", OPT_1_3B},
+    {"OPT6.7B", OPT_6_7B},
+};
+
+std::map<int, std::string> model_path = {
+    {OPT_125M, "models/OPT_125m"},
+    {OPT_1_3B, "models/OPT_1.3B"},
+    {OPT_6_7B, "models/OPT_6.7B"},
+};
+
+int main(int argc, char* argv[]) {
+    std::string target_model = "OPT1.3B";
+
+    if (argc > 1) {
+        auto target_str = argv[1];
+        if (model_config.count(target_str) == 0) {
+            std::cerr << "Model config:" << target_str << " unsupported" << std::endl;
+            std::cerr << "Please select one of the following:";
+            for (const auto& k : model_config) {
+                std::cerr << k.first << ", ";
+            }
+            std::cerr << std::endl;
+            return 1;
+        }
+        std::cout << "Model: " << argv[1] << " selected" << std::endl;
+        target_model = argv[1];
+    } else {
+        std::cout << "Using default model: " + target_model << std::endl;
+    }
+
+    // Load model
+    std::cout << "Loading model... " << std::flush;
+    int model_id = model_config[target_model];
+    std::string m_path = model_path[model_id];
+    OPTForCausalLM model = OPTForCausalLM(m_path, get_opt_model_config(model_id));
+    std::cout << "Finished!" << std::endl;
+
+    // Load encoder
+    std::string vocab_file = "./models/OPT_125m/vocab.json";
+    std::string bpe_file = "./models/OPT_125m/merges.txt";
+    Encoder encoder = get_encoder(vocab_file, bpe_file);
+
+    // Get input from the user
+    std::cout << "Please enter a line of text: ";
+    std::string input;
+    std::getline(std::cin, input);
+    std::vector<int> input_ids = encoder.encode(input);
+    std::string decoded = encoder.decode(input_ids);
+    std::cout << "input:" << decoded << std::endl;
+
+    struct opt_params generation_config;
+    generation_config.n_predict = 256;
+    std::vector<int> generated_ids = OPTGenerate(model, input_ids, generation_config, &encoder, true);
+
+    decoded = encoder.decode(generated_ids);
+};
diff --git a/experimental/transformer/include/OPTGenerate.h b/experimental/transformer/include/OPTGenerate.h
index bafac137..74d31769 100644
--- a/experimental/transformer/include/OPTGenerate.h
+++ b/experimental/transformer/include/OPTGenerate.h
@@ -81,7 +81,5 @@ void OPT_sample_typical(OPT_token_data_array* candidates, float p, size_t min_ke
 void OPT_sample_top_p(OPT_token_data_array* candidates, float p, size_t min_keep);
 
 std::vector<int> OPTGenerate(OPTForCausalLM model, std::vector<int> input_ids,
-                             const struct opt_params generation_config);
-
-void OPTGenerate_interactive(OPTForCausalLM model, std::vector<int> input_ids,
-                             const struct opt_params generation_config, Encoder encoder);
+                             const struct opt_params generation_config, Encoder* encoder = NULL,
+                             bool interactive = false);
diff --git a/experimental/transformer/include/profiler.h b/experimental/transformer/include/profiler.h
index b29203d2..b5188ef1 100644
--- a/experimental/transformer/include/profiler.h
+++ b/experimental/transformer/include/profiler.h
@@ -34,8 +34,7 @@ class Profiler {
         counts[section]++;
     }
 
-    void report() const {
-#ifdef PROFILER
+    void report_internal() const {
         std::cout << "Section, Total time(us), Average time(us), Count, GOPs" << std::endl;
         for (const auto& entry : durations) {
             std::string row;
@@ -43,7 +42,7 @@ class Profiler {
             row += std::to_string(entry.second) + ", ";
             row += std::to_string(entry.second / counts.at(entry.first)) + ", ";
             if (flops.count(entry.first) == 0)
-                row += std::to_string(counts.at(entry.first));
+                row += std::to_string(counts.at(entry.first)) + ", N/A, N/A";
             else {
                 row += std::to_string(counts.at(entry.first)) + ", ";
                 // ops and microsecond
@@ -51,6 +50,11 @@ class Profiler {
             }
             std::cout << row << std::endl;
         }
+    }
+
+    void report() const {
+#ifdef PROFILER
+        report_internal();
 #endif
     }
 
diff --git a/experimental/transformer/include/utils.h b/experimental/transformer/include/utils.h
index 1f94e9d2..c64d3e33 100644
--- a/experimental/transformer/include/utils.h
+++ b/experimental/transformer/include/utils.h
@@ -9,6 +9,10 @@
 
 #include "profiler.h"
 
+#define STATS_START(x) Profiler::getInstance().start(x)
+#define STATS_FLOPS(x, y) Profiler::getInstance().start(x, y)
+#define STATS_END(x) Profiler::getInstance().stop(x)
+
 #ifdef PROFILER
 #define PROFILE_START(x) Profiler::getInstance().start(x)
 #define PROFILE_START_FLOPS(x, y) Profiler::getInstance().start(x, y)
diff --git a/experimental/transformer/src/OPTGenerate.cc b/experimental/transformer/src/OPTGenerate.cc
index a9cfa6fe..43d3a997 100644
--- a/experimental/transformer/src/OPTGenerate.cc
+++ b/experimental/transformer/src/OPTGenerate.cc
@@ -1,6 +1,7 @@
 #include "OPTGenerate.h"
 
 #include "common.h"
+#include "util.h"
 
 void OPT_sample_repetition_penalty(OPT_token_data_array* candidates, const int* last_tokens, size_t last_tokens_size,
                                    float penalty) {
@@ -320,7 +321,7 @@ void OPT_sample_top_p(OPT_token_data_array* candidates, float p, size_t min_keep
 
 // OPTGenerate function
 std::vector<int> OPTGenerate(OPTForCausalLM model, std::vector<int> input_ids,
-                             const struct opt_params generation_config) {
+                             const struct opt_params generation_config, Encoder* encoder, bool interactive) {
     std::vector<int> last_n_tokens(generation_config.n_ctx);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     std::vector<int> embd;
@@ -338,10 +339,14 @@ std::vector<int> OPTGenerate(OPTForCausalLM model, std::vector<int> input_ids,
         }
     }
 
+    if (encoder == NULL) interactive = false;
+    if (interactive) std::cout << "Generated: " << std::endl;
+
     bool has_past_kv = false;
     std::vector<Matrix3D<int8_t>> past_keys, past_values;
     int n_remain = generation_config.n_predict;
     while (n_remain != 0) {
+        STATS_START("Token generation");
         std::vector<float> logits(generation_config.n_vocab);
 
         int sqlen = 1;
@@ -433,8 +438,16 @@ std::vector<int> OPTGenerate(OPTForCausalLM model, std::vector<int> input_ids,
         generate_ids.push_back(id);
         input_ids = std::vector<int>{id};
 
+        if (interactive) std::cout << encoder->decode(input_ids) << std::flush;
+
         --n_remain;
+        STATS_END("Token generation");
     }
 
+    if (interactive) std::cout << std::endl;
+
+    Profiler::getInstance().report_internal();
+    Profiler::getInstance().reset();
+
     return generate_ids;
 }