/* test_llama2c.c (forked from karpathy/llama2.c) */
#include "llama2c.h"
/*
 * This file tests both the usability of llama2c.h
 * and its correctness, to spot any regressions.
 *
 * The actual test cases are retained from karpathy/llama2.c
 */
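/*
 * Example build-and-run, as a sketch only (the repo layout and whether
 * llama2c is header-only are assumptions here; adjust to your setup):
 *
 *   cc -O2 -o test_llama2c test_llama2c.c -lm   # add llama2c.c if the library
 *                                               # is a separate translation unit
 *   ./test_llama2c   # expects models/tokenizer.bin and models/stories15M.bin
 */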
void assert_eq(int a, int b) {
    if (a != b) {
        printf("Assertion failed: %d != %d\n", a, b);
        exit(EXIT_FAILURE);
    }
}
void test_prompt_encoding(Tokenizer *tokenizer, char *prompt, int *expected_tokens, int num_expected_tokens) {
    // encode; +3 leaves room for '\0', an optional BOS, and an optional EOS
    int *prompt_tokens = (int *) malloc((strlen(prompt) + 3) * sizeof(int));
    int num_prompt_tokens = 0; // the total number of prompt tokens
    encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens); // bos=1, eos=0
#if VERBOSITY == 1
    // optionally print expected vs. actual tokens for debugging
    printf("expected tokens:\n");
    for (int i = 0; i < num_expected_tokens; i++) printf("%d ", expected_tokens[i]);
    printf("\n");
    printf("actual tokens:\n");
    for (int i = 0; i < num_prompt_tokens; i++) printf("%d ", prompt_tokens[i]);
    printf("\n");
#endif
    // verify
    assert_eq(num_prompt_tokens, num_expected_tokens);
    for (int i = 0; i < num_prompt_tokens; i++) {
        assert_eq(prompt_tokens[i], expected_tokens[i]);
    }
#if VERBOSITY == 1
    printf("OK\n");
    printf("---\n");
#endif
    free(prompt_tokens);
}
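/*
 * A round-trip sketch, defined but intentionally not called by the tests below.
 * It decodes the tokens back and checks that their concatenation reproduces the
 * prompt. This assumes llama2c.h retains upstream run.c's
 * decode(Tokenizer*, int prev_token, int token) signature, including its
 * leading-space handling after BOS; verify that before wiring it in.
 */
void check_roundtrip(Tokenizer *tokenizer, char *prompt, int *tokens, int n_tokens) {
    char buffer[1024] = "";
    for (int i = 1; i < n_tokens; i++) { // skip the BOS token at index 0
        char *piece = decode(tokenizer, tokens[i - 1], tokens[i]);
        strncat(buffer, piece, sizeof(buffer) - strlen(buffer) - 1);
    }
    if (strcmp(buffer, prompt) != 0) {
        printf("Round-trip mismatch: \"%s\" vs \"%s\"\n", buffer, prompt);
        exit(EXIT_FAILURE);
    }
}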
void test_prompt_encodings() {
    // let's verify that the Tokenizer works as expected
    char *tokenizer_path = "models/tokenizer.bin";
    int vocab_size = 32000;
    Tokenizer tokenizer;
    build_tokenizer(&tokenizer, tokenizer_path, vocab_size);
    // test 0: the empty string (added here as a simple base case)
    char *prompt0 = "";
    int expected_tokens0[] = {1}; // just the BOS token
    test_prompt_encoding(&tokenizer, prompt0, expected_tokens0, sizeof(expected_tokens0) / sizeof(int));
    // the tests below are taken from the Meta Llama 2 repo example code
    // https://github.com/facebookresearch/llama/blob/main/example_text_completion.py
    // and the expected tokens come from me breaking in the debugger in Python
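    // (roughly: Tokenizer(model_path).encode(prompt, bos=True, eos=False) from the
    //  Meta repo's llama/tokenizer.py; that call is an approximation, not verbatim)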
    // test 1
    char *prompt = "I believe the meaning of life is";
    int expected_tokens[] = {1, 306, 4658, 278, 6593, 310, 2834, 338};
    test_prompt_encoding(&tokenizer, prompt, expected_tokens, sizeof(expected_tokens) / sizeof(int));
    // test 2
    char *prompt2 = "Simply put, the theory of relativity states that ";
    int expected_tokens2[] = {1, 3439, 17632, 1925, 29892, 278, 6368, 310, 14215, 537, 5922, 393, 29871};
    test_prompt_encoding(&tokenizer, prompt2, expected_tokens2, sizeof(expected_tokens2) / sizeof(int));
    // test 3
    char *prompt3 = "A brief message congratulating the team on the launch:\n\n        Hi everyone,\n\n        I just ";
    int expected_tokens3[] = {1, 319, 11473, 2643, 378, 629, 271, 18099, 278, 3815, 373, 278, 6826, 29901, 13, 13, 4706,
                              6324, 14332, 29892, 13, 13, 4706, 306, 925, 29871};
    test_prompt_encoding(&tokenizer, prompt3, expected_tokens3, sizeof(expected_tokens3) / sizeof(int));
    // test 4
    char *prompt4 = "Translate English to French:\n\n        sea otter => loutre de mer\n        peppermint => menthe poivrée\n        plush girafe => girafe peluche\n        cheese =>";
    int expected_tokens4[] = {1, 4103, 9632, 4223, 304, 5176, 29901, 13, 13, 4706, 7205, 4932, 357, 1149, 301, 449, 276,
                              316, 2778, 13, 4706, 1236, 407, 837, 524, 1149, 6042, 354, 772, 440, 29878, 1318, 13,
                              4706, 715, 1878, 330, 3055, 1725, 1149, 330, 3055, 1725, 4639, 28754, 13, 4706, 923, 968,
                              1149};
    test_prompt_encoding(&tokenizer, prompt4, expected_tokens4, sizeof(expected_tokens4) / sizeof(int));
    // memory and file handles cleanup
    free_tokenizer(&tokenizer);
}
Llama2cConfig simple_config() {
    Llama2cConfig config;
    config.model_path = "./models/stories15M.bin"; // required on the command line in the CLI; hard-coded for the tests
    config.tokenizer_path = "./models/tokenizer.bin";
    config.temperature = 1.0f;
    config.topp = 0.9f;
    config.steps = 256;
    config.prompt = "I believe the meaning of life is";
    config.rng_seed = 0; // 0 conventionally means "seed from the clock", as in upstream run.c
    config.mode = "generate";
    config.system_prompt = NULL;
    return config;
}
void test_api_llama2c_encode() {
    int *prompt_tokens = NULL;
    int num_prompt_tokens = 0;
    llama2c_encode(simple_config(), &prompt_tokens, &num_prompt_tokens);
    // same expected tokens as test 1 above, since simple_config() uses the same prompt
    int expected_tokens[] = {1, 306, 4658, 278, 6593, 310, 2834, 338};
    assert_eq(num_prompt_tokens, 8);
    for (int i = 0; i < num_prompt_tokens; i++) {
        assert_eq(prompt_tokens[i], expected_tokens[i]);
    }
    free(prompt_tokens);
}
void test_api_llama2c_generate() {
    Llama2cConfig config = simple_config();
    config.prompt = "Hello world!";
    config.steps = 10;
    char *generated = llama2c_generate(config);
    if (strncmp(generated, config.prompt, strlen(config.prompt)) != 0) {
        printf("Generated text doesn't start with the prompt=\"%s\" as expected\n", config.prompt);
        exit(EXIT_FAILURE);
    }
    free(generated); // assuming the caller owns the returned buffer
}
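/*
 * A possible extra regression check, sketched here but intentionally not wired
 * into run_tests(): with temperature 0, sampling should reduce to greedy
 * decoding, so two runs with identical config ought to produce identical text.
 * Both the greedy-at-zero behavior and the caller-owned return buffer are
 * assumptions about llama2c.h; verify them before enabling this test.
 */
void test_api_llama2c_generate_deterministic() {
    Llama2cConfig config = simple_config();
    config.temperature = 0.0f; // greedy decoding, no sampling noise
    config.steps = 16;
    char *first = llama2c_generate(config);
    char *second = llama2c_generate(config);
    if (strcmp(first, second) != 0) {
        printf("Greedy generation was not deterministic\n");
        exit(EXIT_FAILURE);
    }
    free(first);
    free(second);
}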
void run_tests() {
    test_prompt_encodings();
    test_api_llama2c_encode();
    test_api_llama2c_generate();
}
int main(void) {
    run_tests();
    printf("ALL OK\n");
    return EXIT_SUCCESS;
}