Commit c8be5fe: minor fix

RaymondWang0 committed May 23, 2023
1 parent 6eeced7
Showing 2 changed files with 12 additions and 45 deletions.
experimental/transformer/include/OPTGenerate.h (12 additions, 40 deletions)

@@ -49,46 +49,18 @@ struct opt_params {

     // sampling parameters
     std::unordered_map<int, float> logit_bias;  // logit bias for specific tokens
-    int32_t top_k = 40;                // <= 0 to use vocab size
-    float top_p = 0.95f;               // 1.0 = disabled
-    float tfs_z = 1.00f;               // 1.0 = disabled
-    float typical_p = 1.00f;           // 1.0 = disabled
-    float temp = 0.80f;                // 1.0 = disabled
-    float repeat_penalty = 1.10f;      // 1.0 = disabled
-    int32_t repeat_last_n = 64;        // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float frequency_penalty = 0.00f;   // 0.0 = disabled
-    float presence_penalty = 0.00f;    // 0.0 = disabled
-    int mirostat = 0;                  // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float mirostat_tau = 5.00f;        // target entropy
-    float mirostat_eta = 0.10f;        // learning rate
-
-    std::string model = "models/lamma-7B/ggml-model.bin";  // model path
-    std::string prompt = "";
-    std::string path_prompt_cache = "";   // path to file for saving/loading prompt eval state
-    std::string input_prefix = "";        // string to prefix user inputs with
-    std::string input_suffix = "";        // string to suffix user inputs with
-    std::vector<std::string> antiprompt;  // string upon seeing which more user input is prompted
-
-    std::string lora_adapter = "";  // lora adapter path
-    std::string lora_base = "";     // base model path for the lora adapter
-
-    bool memory_f16 = true;         // use f16 instead of f32 for memory kv
-    bool random_prompt = false;     // do not randomize prompt if none provided
-    bool use_color = false;         // use color to distinguish generations and inputs
-    bool interactive = false;       // interactive mode
-    bool prompt_cache_all = false;  // save user input and generations to prompt cache
-
-    bool embedding = false;         // get only sentence embedding
-    bool interactive_first = false; // wait for user input immediately
-    bool multiline_input = false;   // reverse the usage of `\`
-
-    bool instruct = false;          // instruction mode (used for Alpaca models)
-    bool penalize_nl = true;        // consider newlines as a repeatable token
-    bool perplexity = false;        // compute perplexity over the prompt
-    bool use_mmap = true;           // use mmap for faster loads
-    bool use_mlock = false;         // use mlock to keep model in memory
-    bool mem_test = false;          // compute maximum memory usage
-    bool verbose_prompt = false;    // print prompt tokens before generation
+    int32_t top_k = 40;                // <= 0 to use vocab size
+    float top_p = 0.95f;               // 1.0 = disabled
+    float tfs_z = 1.00f;               // 1.0 = disabled
+    float typical_p = 1.00f;           // 1.0 = disabled
+    float temp = 0.80f;                // 1.0 = disabled
+    float repeat_penalty = 1.10f;      // 1.0 = disabled
+    int32_t repeat_last_n = 64;        // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float frequency_penalty = 0.00f;   // 0.0 = disabled
+    float presence_penalty = 0.00f;    // 0.0 = disabled
+    int mirostat = 0;                  // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float mirostat_tau = 5.00f;        // target entropy
+    float mirostat_eta = 0.10f;        // learning rate
 };
 
 void OPT_sample_repetition_penalty(OPT_token_data_array* candidates, const int* last_tokens, size_t last_tokens_size,
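The fields this commit keeps are the knobs a llama.cpp-style sampling chain consumes. As a rough illustration, below is a minimal, self-contained sketch (hypothetical code, not this repository's implementation) of how temp and top_k typically combine before a token is drawn; sample_token and its signature are invented for the example.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical sketch: temperature scaling, top-k truncation, softmax, random draw.
int sample_token(std::vector<float> logits, float temp, int32_t top_k, std::mt19937& rng) {
    // Temperature scaling: temp < 1 sharpens the distribution, temp > 1 flattens it.
    for (float& l : logits) l /= temp;

    // Top-k truncation: keep only the k highest-logit tokens (<= 0 means keep all).
    std::vector<int> ids(logits.size());
    for (size_t i = 0; i < ids.size(); i++) ids[i] = (int)i;
    size_t k = (top_k <= 0) ? ids.size() : std::min((size_t)top_k, ids.size());
    std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });
    ids.resize(k);

    // Softmax over the surviving candidates (subtract the max for numerical stability).
    std::vector<float> probs(k);
    float max_logit = logits[ids[0]];
    float sum = 0.0f;
    for (size_t i = 0; i < k; i++) {
        probs[i] = std::exp(logits[ids[i]] - max_logit);
        sum += probs[i];
    }
    for (float& p : probs) p /= sum;

    // Draw one token id according to the resulting distribution.
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    return ids[dist(rng)];
}

With the struct's defaults the call would be sample_token(logits, 0.80f, 40, rng); passing top_k <= 0 falls back to the full vocabulary, matching the "<= 0 to use vocab size" comment above.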
experimental/transformer/src/OPTGenerate.cc (0 additions, 5 deletions)

@@ -369,7 +369,6 @@ std::vector<int> OPTGenerate(std::vector<int> input_ids,
     const int mirostat = generation_config.mirostat;
     const float mirostat_tau = generation_config.mirostat_tau;
     const float mirostat_eta = generation_config.mirostat_eta;
-    const bool penalize_nl = generation_config.penalize_nl;
     const int n_vocab = generation_config.n_vocab;
 
     // Apply generation_config.logit_bias map
@@ -388,17 +387,13 @@ std::vector<int> OPTGenerate(std::vector<int> input_ids,
     OPT_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
     // Apply penalties
-    //float nl_logit = logits[OPT_token_nl()];
     auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
     OPT_sample_repetition_penalty(&candidates_p,
                                   last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                                   last_n_repeat, repeat_penalty);
     OPT_sample_frequency_and_presence_penalties(&candidates_p,
                                                 last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                                                 last_n_repeat, alpha_frequency, alpha_presence);
-    /*if (!penalize_nl) {
-        logits[OPT_token_nl()] = nl_logit;
-    }*/
 
     int id = 0;
     if (temp <= 0) {
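The two penalty calls retained above follow the shape of llama.cpp's llama_sample_repetition_penalty and llama_sample_frequency_and_presence_penalties. Below is a hedged, self-contained sketch of what samplers with these signatures conventionally do; the struct layout is a simplified stand-in, and the repository's actual implementation may differ in detail.

#include <cstddef>
#include <unordered_map>

// Simplified stand-ins for the candidate-list types used by the samplers.
struct OPT_token_data { int id; float logit; float p; };
struct OPT_token_data_array { OPT_token_data* data; size_t size; bool sorted; };

// Classic repetition penalty: for tokens seen in the recent window, divide
// positive logits by `penalty` and multiply negative ones, so repeated tokens
// always move toward lower probability regardless of sign.
void repetition_penalty(OPT_token_data_array* candidates, const int* last_tokens,
                        size_t last_tokens_size, float penalty) {
    for (size_t i = 0; i < candidates->size; i++) {
        for (size_t j = 0; j < last_tokens_size; j++) {
            if (candidates->data[i].id == last_tokens[j]) {
                float& logit = candidates->data[i].logit;
                logit = (logit <= 0.0f) ? logit * penalty : logit / penalty;
                break;
            }
        }
    }
    candidates->sorted = false;
}

// OpenAI-style penalties: alpha_frequency scales with how often a token
// appeared in the window; alpha_presence is a flat charge for appearing at all.
void frequency_and_presence_penalties(OPT_token_data_array* candidates,
                                      const int* last_tokens, size_t last_tokens_size,
                                      float alpha_frequency, float alpha_presence) {
    std::unordered_map<int, int> counts;
    for (size_t j = 0; j < last_tokens_size; j++) counts[last_tokens[j]]++;
    for (size_t i = 0; i < candidates->size; i++) {
        auto it = counts.find(candidates->data[i].id);
        if (it == counts.end()) continue;
        candidates->data[i].logit -= (float)it->second * alpha_frequency + alpha_presence;
    }
    candidates->sorted = false;
}

The one-sided divide/multiply rule is why repeat_penalty = 1.10f in opt_params is documented as "1.0 = disabled": at exactly 1.0 both branches leave the logit unchanged.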
