Skip to content

Commit

Permalink
Merge pull request #87 from google:refactor-tidy
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615204427
  • Loading branch information
copybara-github committed Mar 12, 2024
2 parents a9aa63f + 7224761 commit 0221956
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 131 deletions.
23 changes: 12 additions & 11 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,6 @@ cc_library(
],
)

cc_library(
name = "app",
hdrs = [
"util/app.h",
],
deps = [
":args",
"@hwy//:hwy",
],
)

cc_library(
name = "gemma_lib",
srcs = [
Expand All @@ -80,6 +69,18 @@ cc_library(
],
)

cc_library(
name = "app",
hdrs = [
"util/app.h",
],
deps = [
":args",
":gemma_lib",
"@hwy//:hwy",
],
)

cc_binary(
name = "gemma",
srcs = [
Expand Down
2 changes: 1 addition & 1 deletion examples/hello_world/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ if (BUILD_MODE STREQUAL "local")
# Relative path to gemma.cpp from examples/hello_world/build/
FetchContent_Declare(gemma SOURCE_DIR ../../..)
else()
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
endif()
FetchContent_MakeAvailable(gemma)

Expand Down
3 changes: 3 additions & 0 deletions examples/hello_world/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/app.h" // LoaderArgs
// copybara:end
#include "hwy/contrib/thread_pool/thread_pool.h"

std::vector<int> tokenize(
Expand Down
118 changes: 0 additions & 118 deletions gemma.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,124 +71,6 @@ struct RuntimeConfig {
int verbosity;
};

struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

static std::string ToLower(const std::string& text) {
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
return result;
}

gcpp::Model ModelType() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
return gcpp::Model::GEMMA_2B;
} else {
return gcpp::Model::GEMMA_7B;
}
}

gcpp::ModelTraining ModelTraining() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
return gcpp::ModelTraining::GEMMA_PT;
} else {
return gcpp::ModelTraining::GEMMA_IT;
}
}

// Returns error string or nullptr if OK.
const char* Validate() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type.empty()) {
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
"2b-it, or 7b-it.";
}
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
"7b-it.";
}
if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required.";
}
if (compressed_weights.path.empty()) {
return "Missing --compressed_weights flag, a file for the compressed "
"model.";
}
return nullptr;
}

Path tokenizer;
Path weights; // uncompressed weights file location
Path compressed_weights; // compressed weights file location
std::string model_type;

template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file.\n Required argument.");
visitor(
compressed_weights, "compressed_weights", Path(),
"Path name of compressed weights file, regenerated from `--weights` "
"file if "
"the compressed weights file does not exist.\n Required argument.");
visitor(model_type, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
" Required argument.");
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file. Only required if "
"compressed_weights file is not present and needs to be "
"regenerated. This parameter is only required for compressing"
"new model weight exports, otherwise it is not needed.");
}
};

struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

size_t max_tokens;
size_t max_generated_tokens;

float temperature;
bool deterministic;
bool multiturn;

// Returns error string or nullptr if OK.
const char* Validate() const {
if (max_tokens > gcpp::kSeqLen) {
return "max_tokens is larger than the maximum sequence length (see "
"configs.h).";
}
if (max_generated_tokens > max_tokens) {
return "Maximum number of generated tokens is larger than the maximum "
"total tokens.";
}
return nullptr;
}

template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(max_tokens, "max_tokens", size_t{3072},
"Maximum number of tokens in prompt + generation.");
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
"Maximum number of tokens to generate.");

visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
visitor(deterministic, "deterministic", false,
"Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", false,
"Multiturn mode\n 0 = clear KV cache after every "
"interaction\n 1 = continue KV cache after every interaction\n "
" Default : 0 (conversation "
"resets every turn)");
}
};

struct GemmaInterface;

struct Gemma {
Expand Down
2 changes: 1 addition & 1 deletion run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
verbosity](int token, float) {
++abs_pos;
++current_pos;
if (current_pos <= prompt_size) {
if (current_pos < prompt_size) {
std::cerr << "." << std::flush;
} else if (token == gcpp::EOS_ID) {
if (!args.multiturn) {
Expand Down
129 changes: 129 additions & 0 deletions util/app.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,28 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_

#include <iterator>
#if HWY_OS_LINUX
#include <sched.h>

#include <cctype>
#include <cerrno> // IDE does not recognize errno.h as providing errno.
#include <string>
#endif
#include <stddef.h>
#include <stdio.h>

#include <algorithm> // std::clamp
#include <thread> // NOLINT>

// copybara:import_next_line:gemma_cpp
#include "configs.h"
// copybara:end

// copybara:import_next_line:gemma_cpp
#include "gemma.h"
// copybara:end

// copybara:import_next_line:gemma_cpp
#include "util/args.h"
// copybara:end
Expand Down Expand Up @@ -116,6 +127,124 @@ class AppArgs : public ArgsBase<AppArgs> {
}
};

struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

static std::string ToLower(const std::string& text) {
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
return result;
}

gcpp::Model ModelType() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
return gcpp::Model::GEMMA_2B;
} else {
return gcpp::Model::GEMMA_7B;
}
}

gcpp::ModelTraining ModelTraining() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
return gcpp::ModelTraining::GEMMA_PT;
} else {
return gcpp::ModelTraining::GEMMA_IT;
}
}

// Returns error string or nullptr if OK.
const char* Validate() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type.empty()) {
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
"2b-it, or 7b-it.";
}
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
"7b-it.";
}
if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required.";
}
if (compressed_weights.path.empty()) {
return "Missing --compressed_weights flag, a file for the compressed "
"model.";
}
return nullptr;
}

Path tokenizer;
Path weights; // uncompressed weights file location
Path compressed_weights; // compressed weights file location
std::string model_type;

template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file.\n Required argument.");
visitor(
compressed_weights, "compressed_weights", Path(),
"Path name of compressed weights file, regenerated from `--weights` "
"file if "
"the compressed weights file does not exist.\n Required argument.");
visitor(model_type, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
" Required argument.");
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file. Only required if "
"compressed_weights file is not present and needs to be "
"regenerated. This parameter is only required for compressing"
"new model weight exports, otherwise it is not needed.");
}
};

struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

size_t max_tokens;
size_t max_generated_tokens;

float temperature;
bool deterministic;
bool multiturn;

// Returns error string or nullptr if OK.
const char* Validate() const {
if (max_tokens > gcpp::kSeqLen) {
return "max_tokens is larger than the maximum sequence length (see "
"configs.h).";
}
if (max_generated_tokens > max_tokens) {
return "Maximum number of generated tokens is larger than the maximum "
"total tokens.";
}
return nullptr;
}

template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(max_tokens, "max_tokens", size_t{3072},
"Maximum number of tokens in prompt + generation.");
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
"Maximum number of tokens to generate.");

visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
visitor(deterministic, "deterministic", false,
"Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", false,
"Multiturn mode\n 0 = clear KV cache after every "
"interaction\n 1 = continue KV cache after every interaction\n "
" Default : 0 (conversation "
"resets every turn)");
}
};

} // namespace gcpp

#endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_

0 comments on commit 0221956

Please sign in to comment.