-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmeasure_pexplexity.py
107 lines (78 loc) · 3.82 KB
/
measure_pexplexity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Measures perplexity and per-token latency of an RWKV model on a given text file.
# Perplexity is defined here as exp() of average cross-entropy loss.
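# Usage: python measure_pexplexity.py C:\rwkv.cpp-169M.bin C:\text.txt 1024
# Here 1024 is ignore_first_n_tokens (the number of leading tokens excluded from the loss);
# an optional fourth argument, token_limit, caps how many tokens of the text are processed.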
import os
import time
import argparse
# TODO Get rid of this PyTorch dependency by writing a cross_entropy impl for numpy
import torch
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
from tokenizer_util import add_tokenizer_argument, get_tokenizer
from typing import List
def parse_args():
    """Define and parse this script's command-line arguments.

    Positional arguments: model_path, text_path, ignore_first_n_tokens, and an
    optional token_limit that defaults to -1 (meaning: process the whole text).
    The tokenizer option is registered by add_tokenizer_argument.

    Returns:
        The argparse.Namespace produced from sys.argv.
    """
    cli = argparse.ArgumentParser(
        description='Measure perplexity and per-token latency of an RWKV model on a given text file'
    )

    cli.add_argument('model_path', type=str, help='Path to model checkpoint file')
    cli.add_argument('text_path', type=str, help='Path to text file in UTF-8 encoding')
    cli.add_argument(
        'ignore_first_n_tokens',
        type=int,
        help='How many tokens should be skipped before loss is measured',
    )
    # nargs='?' makes token_limit optional; -1 is the "no limit" sentinel.
    cli.add_argument(
        'token_limit',
        type=int,
        nargs='?',
        default=-1,
        help='How many tokens to process; set to -1 to process all text',
    )

    add_tokenizer_argument(cli)

    namespace = cli.parse_args()
    return namespace
args = parse_args()

# Validate command-line values up front, before the (expensive) model load.
# Explicit raises are used instead of `assert`, which is stripped under `python -O`.
if args.token_limit != -1 and args.token_limit <= 0:
    raise ValueError('Invalid token_limit')

# A negative skip count would let run_count drop to 0 or below and later divide by zero.
if args.ignore_first_n_tokens < 0:
    raise ValueError('Invalid ignore_first_n_tokens')

print('Loading model')

model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    args.model_path
)

print('Loading text')

# Context manager so the file handle is closed deterministically.
with open(args.text_path, encoding='utf-8') as text_file:
    text: str = text_file.read()

# The tokenizer is chosen from the CLI option and the model's vocabulary size.
_, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)

tokens: List[int] = tokenizer_encode(text)

token_count: int = len(tokens)
print(f'{token_count} tokens in the text')

token_limit: int = args.token_limit

if token_limit != -1 and token_count > token_limit:
    tokens = tokens[0:token_limit]
    token_count = token_limit
    print(f'Text was limited to {token_limit} tokens')

# The evaluation loop scores predictions of tokens[ignore_first_n_tokens:] (token 0 is
# never a target), so this guarantees at least two scored predictions.
if token_count - args.ignore_first_n_tokens <= 1:
    raise ValueError('Need at least 2 tokens for evaluation')
# ---
def format_loss(loss: torch.Tensor) -> str:
    """Format each element of a 1-D loss tensor with 3 decimals, comma-separated.

    Example: a tensor holding [1.0, 2.5] becomes '1.000, 2.500'; an empty tensor
    yields ''.
    """
    # tolist() converts all elements to Python floats in one call, replacing the
    # per-index .item() loop and the str(list)/replace/slice formatting trick.
    return ', '.join('%.3f' % (value,) for value in loss.tolist())
def format_loss_with_perplexity(loss: torch.Tensor) -> str:
    """Render a loss tensor together with the perplexity of its first element.

    Perplexity is computed as exp(loss[0]); the result looks like
    'loss [2.000], perplexity 7.389'.
    """
    loss_text = format_loss(loss)
    perplexity_text = '%.3f' % (torch.exp(loss[0]).item(),)
    return 'loss [' + loss_text + '], perplexity ' + perplexity_text
# ---
logits, state = None, None

# Accumulate in float64: summing many float32 per-token losses loses precision on long texts.
loss_sum: torch.Tensor = torch.tensor([0.0], dtype=torch.float64)
loss_count: int = 0

start: float = time.time()

# Each step feeds tokens[i] and scores the prediction of tokens[i + 1],
# so the final token is only ever a target.
run_count: int = token_count - 1

# Report progress about 10 times; max(1, ...) prevents `i % 0` (ZeroDivisionError)
# when run_count is below 10. For short runs this reports every token.
progress_interval: int = max(1, run_count // 10)

for i in range(run_count):
    token: int = tokens[i]
    target: int = tokens[i + 1]

    # state/logits from the previous step are passed back in.
    # NOTE(review): presumably they are reused as output buffers — confirm in rwkv_cpp_model.
    logits, state = model.eval(token, state, state, logits)

    # Only targets at index >= ignore_first_n_tokens contribute to the loss.
    if args.ignore_first_n_tokens == 0 or i + 1 >= args.ignore_first_n_tokens:
        losses = torch.tensor([
            torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long), reduction='none').item()
        ])

        loss_sum += losses
        loss_count += 1

    if i % progress_interval == 0:
        avg_loss_so_far = loss_sum / loss_count

        # ETA is extrapolated from the average wall-clock time per processed token.
        duration: float = time.time() - start
        duration_per_token: float = duration / (i + 1)
        runs_remaining: int = run_count - i - 1
        duration_remaining: int = int(runs_remaining * duration_per_token)

        print(f'Token #{i}/{token_count}, '
              f'{int(100.0 * i / token_count)}%, '
              f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='')

        # Nothing has been scored yet while still inside the skipped prefix.
        if loss_count > 0:
            print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}')
        else:
            print()

print()

print(f'Model: {os.path.basename(args.model_path)}, '
      f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
      f'skipped {args.ignore_first_n_tokens} tokens, '
      f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}, '
      f'latency {int((time.time() - start) * 1000 / run_count)} ms per token')