-
Notifications
You must be signed in to change notification settings - Fork 1
/
prompt_store.py
137 lines (112 loc) · 4.36 KB
/
prompt_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from dataclasses import dataclass
import torch as t
from jaxtyping import Float, Int
from transformers import PreTrainedTokenizerBase
@dataclass
class PromptSet:
    """One IOI (indirect-object identification) example: a clean prompt, its
    answer pair, and matched corrupted / off-distribution control prompts."""

    # Prompt whose next token should be the indirect object (the correct answer).
    clean_prompt: str
    # Expected completion of the clean prompt (leading space included so it
    # tokenises as a standalone word, e.g. " Mary").
    correct_answer: str
    # The distractor name (the subject), also with a leading space.
    incorrect_answer: str
    # Same sentence with the two name roles swapped, used as the corrupted input.
    corrupted_prompt: str
    # Unrelated three-name sentence shared by default across all examples,
    # serving as an off-distribution baseline.
    off_distribution_prompt: str = "When Owen and Henry went to the shops, Jack gave the bag to"
class PromptStore(list[PromptSet]):
    """A list of PromptSet examples bundled with the tokeniser used to encode them.

    Behaves as a plain ``list[PromptSet]`` while exposing convenience properties
    that batch-tokenise the stored prompts and answers.
    """

    def __init__(self, prompts: list[PromptSet], tokeniser: PreTrainedTokenizerBase):
        super().__init__(prompts)
        self.tokeniser = tokeniser

    def str_to_token_tensors(self, text: list[str]) -> Int[t.Tensor, "examples seq_len"]:
        """Tokenise a batch of strings into a single (examples, seq_len) id tensor.

        NOTE(review): the tokeniser is called without padding, so this assumes
        every string in ``text`` tokenises to the same length — confirm when
        adding new prompts.
        """
        # With return_tensors="pt" the tokeniser already yields a torch tensor;
        # re-wrapping it in t.tensor() was a redundant copy that triggers
        # torch's copy-construct warning, so return it directly.
        return self.tokeniser(text, return_tensors="pt")["input_ids"]

    @property
    def clean_prompts(self) -> list[str]:
        """Clean prompt strings, one per example."""
        return [prompt.clean_prompt for prompt in self]

    @property
    def clean_tokens(self) -> Int[t.Tensor, "examples seq_len"]:
        """Token ids of the clean prompts."""
        return self.str_to_token_tensors(self.clean_prompts)

    @property
    def corrupted_prompts(self) -> list[str]:
        """Corrupted (role-swapped) prompt strings, one per example."""
        return [prompt.corrupted_prompt for prompt in self]

    @property
    def corrupted_tokens(self) -> Int[t.Tensor, "examples seq_len"]:
        """Token ids of the corrupted prompts."""
        return self.str_to_token_tensors(self.corrupted_prompts)

    @property
    def off_distribution_prompts(self) -> list[str]:
        """Off-distribution baseline prompt strings, one per example."""
        return [prompt.off_distribution_prompt for prompt in self]

    @property
    def off_distribution_tokens(self) -> Int[t.Tensor, "examples seq_len"]:
        """Token ids of the off-distribution prompts."""
        return self.str_to_token_tensors(self.off_distribution_prompts)

    @property
    def correct_answers(self) -> list[str]:
        """Correct answer strings, one per example."""
        return [prompt.correct_answer for prompt in self]

    @property
    def incorrect_answers(self) -> list[str]:
        """Incorrect answer strings, one per example."""
        return [prompt.incorrect_answer for prompt in self]

    @property
    def answer_token_indices(self) -> Int[t.Tensor, "examples 2"]:
        """Token ids of (correct, incorrect) answers, stacked as (examples, 2).

        Assumes each answer string tokenises to exactly one token (hence the
        leading space in the answer strings).
        """
        correct_answer_token_ids = self.str_to_token_tensors(self.correct_answers)  # (examples, 1)
        incorrect_answer_token_ids = self.str_to_token_tensors(
            self.incorrect_answers
        )  # (examples, 1)
        answer_token_indices = t.cat(
            [correct_answer_token_ids, incorrect_answer_token_ids], dim=1
        )  # (examples, 2)
        return answer_token_indices

    def prepare_tokens_and_indices(
        self,
    ) -> tuple[
        Int[t.Tensor, "examples seq_len"],
        Int[t.Tensor, "examples seq_len"],
        Int[t.Tensor, "examples seq_len"],
        Int[t.Tensor, "examples 2"],
    ]:
        """Prepare the tokens and indices for the IOI task.

        Returns
        -------
        clean_tokens : Int[t.Tensor, "examples seq_len"]
        corrupted_tokens : Int[t.Tensor, "examples seq_len"]
        off_distribution_tokens : Int[t.Tensor, "examples seq_len"]
        answer_token_indices : Int[t.Tensor, "examples 2"]
        """
        return (
            self.clean_tokens,
            self.corrupted_tokens,
            self.off_distribution_tokens,
            self.answer_token_indices,
        )
def build_prompt_store(tokeniser: PreTrainedTokenizerBase) -> PromptStore:
    """Construct the default IOI PromptStore: four hand-written example sets,
    each a (clean, correct, incorrect, corrupted) quadruple."""
    example_rows = [
        (
            "When John and Mary went to the shops, John gave the bag to",
            " Mary",
            " John",
            "When John and Mary went to the shops, Mary gave the bag to",
        ),
        (
            "When Tom and James went to the park, James gave the ball to",
            " Tom",
            " James",
            "When Tom and James went to the park, Tom gave the ball to",
        ),
        (
            "When Dan and Sid went to the shops, Sid gave an apple to",
            " Dan",
            " Sid",
            "When Dan and Sid went to the shops, Dan gave an apple to",
        ),
        (
            "After Martin and Amy went to the park, Amy gave a drink to",
            " Martin",
            " Amy",
            "After Martin and Amy went to the park, Martin gave a drink to",
        ),
    ]
    prompt_sets = [
        PromptSet(clean, correct, incorrect, corrupted)
        for clean, correct, incorrect, corrupted in example_rows
    ]
    return PromptStore(prompt_sets, tokeniser=tokeniser)
if __name__ == "__main__":
    # Smoke test: build the default prompt store with the GPT-2 tokeniser and
    # print the clean and off-distribution token tensors.
    from transformers import GPT2Tokenizer

    tokeniser = GPT2Tokenizer.from_pretrained("gpt2")
    prompt_store = build_prompt_store(tokeniser)
    prepared = prompt_store.prepare_tokens_and_indices()
    clean_tokens, corrupted_tokens, off_distribution_tokens, answer_token_indices = prepared
    print(clean_tokens)
    print(off_distribution_tokens)