Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
RaymondWang0 committed May 23, 2023
1 parent 9170525 commit 222ad20
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions experimental/transformer/src/OPTGenerate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@ void OPT_sample_repetition_penalty(OPT_token_data_array * candidates, const int
continue;
}

// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
if (candidates->data[i].logit <= 0) {
candidates->data[i].logit *= penalty;
} else {
}
else {
candidates->data[i].logit /= penalty;
}
}
Expand Down Expand Up @@ -70,7 +69,6 @@ void OPT_sample_temperature(OPT_token_data_array * candidates_p, float temp) {
//
// sampling
//

void OPT_sample_softmax(OPT_token_data_array * candidates) {
assert(candidates->size > 0);

Expand Down Expand Up @@ -122,7 +120,8 @@ void OPT_sample_top_k(OPT_token_data_array * candidates, int k, size_t min_keep)
};
if (k == (int) candidates->size) {
std::sort(candidates->data, candidates->data + candidates->size, comp);
} else {
}
else {
std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
}
candidates->sorted = true;
Expand Down Expand Up @@ -331,7 +330,6 @@ std::vector<int> OPTGenerate(std::vector<int> input_ids,
std::vector<int> embd;
std::vector<int> generate_ids;

//int n_past = 0;
int n_consumed = 0;
while ((int) input_ids.size() > n_consumed) {
embd.push_back(input_ids[n_consumed]);
Expand Down Expand Up @@ -397,7 +395,6 @@ std::vector<int> OPTGenerate(std::vector<int> input_ids,

int id = 0;
if (temp <= 0) {
// Greedy sampling
id = OPT_sample_token_greedy(&candidates_p);
}
else {
Expand All @@ -422,15 +419,12 @@ std::vector<int> OPTGenerate(std::vector<int> input_ids,
id = OPT_sample_token(&candidates_p);
}
}
// printf("`%d`", candidates_p.size);

// add it to the context
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
embd.push_back(id);
generate_ids.push_back(id);

// decrement remaining sampling budget
--n_remain;
}

Expand Down

0 comments on commit 222ad20

Please sign in to comment.