Merge pull request #728 from callummcdougall/callum/new-test-prompt
`utils.test_prompt` compares multiple prompts
bryce13950 authored Sep 26, 2024
2 parents 391fe55 + ad0dea0 commit 762bb5d
Showing 1 changed file with 59 additions and 21 deletions.
80 changes: 59 additions & 21 deletions transformer_lens/utils.py
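
For orientation, a minimal usage sketch of the new multi-answer path (a hedged illustration, not code from this commit: the model choice, prompt, and answers are assumptions):

from transformer_lens import HookedTransformer
from transformer_lens.utils import test_prompt

model = HookedTransformer.from_pretrained("gpt2")  # any HookedTransformer works

# Passing a list of answers ranks each candidate's first token against the
# same next-token distribution over the prompt.
test_prompt(
    "The Eiffel Tower is in the city of",
    [" Paris", " London", " Berlin"],  # illustrative candidates
    model,
)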
@@ -681,7 +681,7 @@ def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor

 def test_prompt(
     prompt: str,
-    answer: str,
+    answer: Union[str, list[str]],
     model,  # Can't give type hint due to circular imports
     prepend_space_to_answer: bool = True,
     print_details: bool = True,
@@ -728,7 +728,9 @@ def test_prompt(
         answer:
             The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need
             to think about if you have a space before the answer here (as e.g. in this example the
-            answer may really be " road" if the prompt ends without a trailing space).
+            answer may really be " road" if the prompt ends without a trailing space). If this is a
+            list of strings, then we only look at the next-token completion, and we compare them all
+            as possible model answers.
         model:
             The model.
         prepend_space_to_answer:
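
The main hunk below supports the list case by repeating the one prompt along the batch dimension and concatenating a single answer token per row. A standalone shape sketch of that batching (token ids are made up; only the tensor shapes matter):

import torch

prompt_tokens = torch.tensor([[50256, 464, 3139]])       # [1, prompt_len]
answer_tokens = torch.tensor([[6342], [3576], [11307]])  # [n_answers, 1]
prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1)  # -> [3, prompt_len]
tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)
print(tokens.shape)  # torch.Size([3, 4]): one row per candidate answer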
@@ -748,44 +750,80 @@ def test_prompt(
     Returns:
         None (just prints the results directly).
     """
-    if prepend_space_to_answer and not answer.startswith(" "):
-        answer = " " + answer
+    answers = [answer] if isinstance(answer, str) else answer
+    n_answers = len(answers)
+    using_multiple_answers = n_answers > 1
+
+    if prepend_space_to_answer:
+        answers = [answer if answer.startswith(" ") else " " + answer for answer in answers]
 
     # GPT-2 often treats the first token weirdly, so lets give it a resting position
     prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
-    answer_tokens = model.to_tokens(answer, prepend_bos=False)
+    answer_tokens = model.to_tokens(answers, prepend_bos=False)
+
+    # If we have multiple answers, we're only allowed a single token generation
+    if using_multiple_answers:
+        answer_tokens = answer_tokens[:, :1]
+
+    # Deal with case where answers is a list of strings
+    prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1)
     tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)
 
     prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
-    answer_str_tokens = model.to_str_tokens(answer, prepend_bos=False)
+    answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers]
     prompt_length = len(prompt_str_tokens)
-    answer_length = len(answer_str_tokens)
+    answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0])
     if print_details:
         print("Tokenized prompt:", prompt_str_tokens)
-        print("Tokenized answer:", answer_str_tokens)
-    logits = remove_batch_dim(model(tokens))
+        if using_multiple_answers:
+            print("Tokenized answers:", answer_str_tokens_list)
+        else:
+            print("Tokenized answer:", answer_str_tokens_list[0])
+    logits = model(tokens)
     probs = logits.softmax(dim=-1)
     answer_ranks = []
+
     for index in range(prompt_length, prompt_length + answer_length):
-        answer_token = tokens[0, index]
-        answer_str_token = answer_str_tokens[index - prompt_length]
+        # Get answer tokens for this sequence position
+        answer_tokens = tokens[:, index]
+        answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list]
         # Offset by 1 because models predict the NEXT token
-        token_probs = probs[index - 1]
-        sorted_token_probs, sorted_token_values = token_probs.sort(descending=True)
-        # Janky way to get the index of the token in the sorted list - I couldn't find a better way?
-        correct_rank = torch.arange(len(sorted_token_values))[
-            (sorted_token_values == answer_token).cpu()
-        ].item()
-        answer_ranks.append((answer_str_token, correct_rank))
+        token_probs = probs[:, index - 1]
+        sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True)
+        answer_token_ranks = sorted_token_positions.argsort(-1)[
+            range(n_answers), answer_tokens.cpu()
+        ].tolist()
+        answer_ranks.append(
+            [
+                (answer_str_token, answer_token_rank)
+                for answer_str_token, answer_token_rank in zip(
+                    answer_str_tokens, answer_token_ranks
+                )
+            ]
+        )
         if print_details:
             # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
             # rprint gives rich text printing
             rprint(
-                f"Performance on answer token:\n[b]Rank: {correct_rank: <8} Logit: {logits[index-1, answer_token].item():5.2f} Prob: {token_probs[answer_token].item():6.2%} Token: |{answer_str_token}|[/b]"
+                f"Performance on answer token{'s' if n_answers > 1 else ''}:\n"
+                + "\n".join(
+                    [
+                        f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]"
+                        for i in range(n_answers)
+                    ]
+                )
             )
             for i in range(top_k):
                 print(
-                    f"Top {i}th token. Logit: {logits[index-1, sorted_token_values[i]].item():5.2f} Prob: {sorted_token_probs[i].item():6.2%} Token: |{model.to_string(sorted_token_values[i])}|"
+                    f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|"
                 )
-    rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")
+
+    # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function
+    if not using_multiple_answers:
+        single_answer_ranks = [r[0] for r in answer_ranks]
+        rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}")
+    else:
+        rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")
 
 
 def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
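
A note on the rank computation the diff introduces: sorting the probabilities in descending order yields token ids ordered by rank, and taking argsort of that permutation inverts it, mapping each token id back to its rank. A standalone sketch with toy numbers (not from the commit):

import torch

probs = torch.tensor([[0.1, 0.5, 0.4]])  # one batch row over a 3-token vocab
_, sorted_token_positions = probs.sort(descending=True)  # tensor([[1, 2, 0]])
ranks = sorted_token_positions.argsort(-1)  # inverse permutation: tensor([[2, 0, 1]])
answer_token = 2
print(ranks[0, answer_token].item())  # 1 -> second most likely token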
