From 60d054e041447e2e275cf2c8dff270c38f452f4b Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sun, 10 Mar 2024 23:49:25 -0400 Subject: [PATCH 1/2] move arg definitions out of gemma.h to app.h --- examples/hello_world/CMakeLists.txt | 2 +- examples/hello_world/run.cc | 3 + gemma.h | 118 --------------------------- util/app.h | 119 ++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 119 deletions(-) diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 397b957..9d44f04 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -35,7 +35,7 @@ if (BUILD_MODE STREQUAL "local") else() FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837) endif() -FetchContent_MakeAvailabl(gemma) +FetchContent_MakeAvailable(gemma) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 8ec784f..a994f31 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -21,6 +21,9 @@ // copybara:import_next_line:gemma_cpp #include "util/args.h" // copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" // LoaderArgs +// copybara:end #include "hwy/contrib/thread_pool/thread_pool.h" std::vector tokenize( diff --git a/gemma.h b/gemma.h index 48fb52b..cdd4873 100644 --- a/gemma.h +++ b/gemma.h @@ -71,124 +71,6 @@ struct RuntimeConfig { int verbosity; }; -struct LoaderArgs : public ArgsBase { - LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - - static std::string ToLower(const std::string& text) { - std::string result = text; - std::transform(begin(result), end(result), begin(result), - [](unsigned char c) { return std::tolower(c); }); - return result; - } - - gcpp::Model ModelType() const { - const std::string model_type_lc = ToLower(model_type); - if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") { - return gcpp::Model::GEMMA_2B; - } else { - return gcpp::Model::GEMMA_7B; - } - } - - gcpp::ModelTraining ModelTraining() const { - const std::string model_type_lc = ToLower(model_type); - if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") { - return gcpp::ModelTraining::GEMMA_PT; - } else { - return gcpp::ModelTraining::GEMMA_IT; - } - } - - // Returns error string or nullptr if OK. - const char* Validate() const { - const std::string model_type_lc = ToLower(model_type); - if (model_type.empty()) { - return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " - "2b-it, or 7b-it."; - } - if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && - model_type_lc != "2b-it" && model_type_lc != "7b-it") { - return "Model type must be 2b-pt, 7b-pt, 2b-it, or " - "7b-it."; - } - if (tokenizer.path.empty()) { - return "Missing --tokenizer flag, a file for the tokenizer is required."; - } - if (compressed_weights.path.empty()) { - return "Missing --compressed_weights flag, a file for the compressed " - "model."; - } - return nullptr; - } - - Path tokenizer; - Path weights; // uncompressed weights file location - Path compressed_weights; // compressed weights file location - std::string model_type; - - template - void ForEach(const Visitor& visitor) { - visitor(tokenizer, "tokenizer", Path(), - "Path name of tokenizer model file.\n Required argument."); - visitor( - compressed_weights, "compressed_weights", Path(), - "Path name of compressed weights file, regenerated from `--weights` " - "file if " - "the compressed weights file does not exist.\n Required argument."); - visitor(model_type, "model", std::string(), - "Model type\n 2b-it = 2B parameters, instruction-tuned\n " - "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " - "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n" - " Required argument."); - visitor(weights, "weights", Path(), - "Path name of model weights (.sbs) file. Only required if " - "compressed_weights file is not present and needs to be " - "regenerated. This parameter is only required for compressing" - "new model weight exports, otherwise it is not needed."); - } -}; - -struct InferenceArgs : public ArgsBase { - InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - - size_t max_tokens; - size_t max_generated_tokens; - - float temperature; - bool deterministic; - bool multiturn; - - // Returns error string or nullptr if OK. - const char* Validate() const { - if (max_tokens > gcpp::kSeqLen) { - return "max_tokens is larger than the maximum sequence length (see " - "configs.h)."; - } - if (max_generated_tokens > max_tokens) { - return "Maximum number of generated tokens is larger than the maximum " - "total tokens."; - } - return nullptr; - } - - template - void ForEach(const Visitor& visitor) { - visitor(max_tokens, "max_tokens", size_t{3072}, - "Maximum number of tokens in prompt + generation."); - visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, - "Maximum number of tokens to generate."); - - visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); - visitor(deterministic, "deterministic", false, - "Make top-k sampling deterministic", 2); - visitor(multiturn, "multiturn", false, - "Multiturn mode\n 0 = clear KV cache after every " - "interaction\n 1 = continue KV cache after every interaction\n " - " Default : 0 (conversation " - "resets every turn)"); - } -}; - struct GemmaInterface; struct Gemma { diff --git a/util/app.h b/util/app.h index 754b2fb..8a59e09 100644 --- a/util/app.h +++ b/util/app.h @@ -98,6 +98,125 @@ class AppArgs : public ArgsBase { } }; +struct LoaderArgs : public ArgsBase { + LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + + static std::string ToLower(const std::string& text) { + std::string result = text; + std::transform(begin(result), end(result), begin(result), + [](unsigned char c) { return std::tolower(c); }); + return result; + } + + gcpp::Model ModelType() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") { + return gcpp::Model::GEMMA_2B; + } else { + return gcpp::Model::GEMMA_7B; + } + } + + gcpp::ModelTraining ModelTraining() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") { + return gcpp::ModelTraining::GEMMA_PT; + } else { + return gcpp::ModelTraining::GEMMA_IT; + } + } + + // Returns error string or nullptr if OK. + const char* Validate() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type.empty()) { + return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " + "2b-it, or 7b-it."; + } + if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && + model_type_lc != "2b-it" && model_type_lc != "7b-it") { + return "Model type must be 2b-pt, 7b-pt, 2b-it, or " + "7b-it."; + } + if (tokenizer.path.empty()) { + return "Missing --tokenizer flag, a file for the tokenizer is required."; + } + if (compressed_weights.path.empty()) { + return "Missing --compressed_weights flag, a file for the compressed " + "model."; + } + return nullptr; + } + + Path tokenizer; + Path weights; // uncompressed weights file location + Path compressed_weights; // compressed weights file location + std::string model_type; + + template + void ForEach(const Visitor& visitor) { + visitor(tokenizer, "tokenizer", Path(), + "Path name of tokenizer model file.\n Required argument."); + visitor( + compressed_weights, "compressed_weights", Path(), + "Path name of compressed weights file, regenerated from `--weights` " + "file if " + "the compressed weights file does not exist.\n Required argument."); + visitor(model_type, "model", std::string(), + "Model type\n 2b-it = 2B parameters, instruction-tuned\n " + "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " + "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n" + " Required argument."); + visitor(weights, "weights", Path(), + "Path name of model weights (.sbs) file. Only required if " + "compressed_weights file is not present and needs to be " + "regenerated. This parameter is only required for compressing" + "new model weight exports, otherwise it is not needed."); + } +}; + + +struct InferenceArgs : public ArgsBase { + InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + + size_t max_tokens; + size_t max_generated_tokens; + + float temperature; + bool deterministic; + bool multiturn; + + // Returns error string or nullptr if OK. + const char* Validate() const { + if (max_tokens > gcpp::kSeqLen) { + return "max_tokens is larger than the maximum sequence length (see " + "configs.h)."; + } + if (max_generated_tokens > max_tokens) { + return "Maximum number of generated tokens is larger than the maximum " + "total tokens."; + } + return nullptr; + } + + template + void ForEach(const Visitor& visitor) { + visitor(max_tokens, "max_tokens", size_t{3072}, + "Maximum number of tokens in prompt + generation."); + visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, + "Maximum number of tokens to generate."); + + visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); + visitor(deterministic, "deterministic", false, + "Make top-k sampling deterministic", 2); + visitor(multiturn, "multiturn", false, + "Multiturn mode\n 0 = clear KV cache after every " + "interaction\n 1 = continue KV cache after every interaction\n " + " Default : 0 (conversation " + "resets every turn)"); + } +}; + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ From 72247614bb2a3d9781c4326cd19029f3fe3e43f8 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Tue, 12 Mar 2024 15:10:44 -0400 Subject: [PATCH 2/2] fix prefill feedback off-by-1, update fetch commit hash --- examples/hello_world/CMakeLists.txt | 2 +- run.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 9d44f04..292c80c 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -33,7 +33,7 @@ if (BUILD_MODE STREQUAL "local") # Relative path to gemma.cpp from examples/hello_world/build/ FetchContent_Declare(gemma SOURCE_DIR ../../..) else() - FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837) + FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e) endif() FetchContent_MakeAvailable(gemma) diff --git a/run.cc b/run.cc index 8bc0910..d6bf22d 100644 --- a/run.cc +++ b/run.cc @@ -119,7 +119,7 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, verbosity](int token, float) { ++abs_pos; ++current_pos; - if (current_pos <= prompt_size) { + if (current_pos < prompt_size) { std::cerr << "." << std::flush; } else if (token == gcpp::EOS_ID) { if (!args.multiturn) {