diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py
index 7e5828e14..421d35e15 100644
--- a/transformer_lens/utils.py
+++ b/transformer_lens/utils.py
@@ -681,7 +681,7 @@ def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor
 
 def test_prompt(
     prompt: str,
-    answer: str,
+    answer: Union[str, list[str]],
     model,  # Can't give type hint due to circular imports
     prepend_space_to_answer: bool = True,
     print_details: bool = True,
@@ -728,7 +728,9 @@ def test_prompt(
         answer:
             The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need
             to think about if you have a space before the answer here (as e.g. in this example the
-            answer may really be " road" if the prompt ends without a trailing space).
+            answer may really be " road" if the prompt ends without a trailing space). If this is a
+            list of strings, then we only look at the next-token completion, and we compare them all
+            as possible model answers.
         model:
             The model.
         prepend_space_to_answer:
@@ -748,44 +750,80 @@ def test_prompt(
     Returns:
         None (just prints the results directly).
     """
-    if prepend_space_to_answer and not answer.startswith(" "):
-        answer = " " + answer
+    answers = [answer] if isinstance(answer, str) else answer
+    n_answers = len(answers)
+    using_multiple_answers = n_answers > 1
+
+    if prepend_space_to_answer:
+        answers = [answer if answer.startswith(" ") else " " + answer for answer in answers]
+
     # GPT-2 often treats the first token weirdly, so lets give it a resting position
     prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
-    answer_tokens = model.to_tokens(answer, prepend_bos=False)
+    answer_tokens = model.to_tokens(answers, prepend_bos=False)
+
+    # If we have multiple answers, we're only allowed a single token generation
+    if using_multiple_answers:
+        answer_tokens = answer_tokens[:, :1]
+
+    # Deal with case where answers is a list of strings
+    prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1)
     tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)
+
     prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
-    answer_str_tokens = model.to_str_tokens(answer, prepend_bos=False)
+    answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers]
     prompt_length = len(prompt_str_tokens)
-    answer_length = len(answer_str_tokens)
+    answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0])
     if print_details:
         print("Tokenized prompt:", prompt_str_tokens)
-        print("Tokenized answer:", answer_str_tokens)
-    logits = remove_batch_dim(model(tokens))
+        if using_multiple_answers:
+            print("Tokenized answers:", answer_str_tokens_list)
+        else:
+            print("Tokenized answer:", answer_str_tokens_list[0])
+    logits = model(tokens)
     probs = logits.softmax(dim=-1)
     answer_ranks = []
+
     for index in range(prompt_length, prompt_length + answer_length):
-        answer_token = tokens[0, index]
-        answer_str_token = answer_str_tokens[index - prompt_length]
+        # Get answer tokens for this sequence position
+        answer_tokens = tokens[:, index]
+        answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list]
         # Offset by 1 because models predict the NEXT token
-        token_probs = probs[index - 1]
-        sorted_token_probs, sorted_token_values = token_probs.sort(descending=True)
-        # Janky way to get the index of the token in the sorted list - I couldn't find a better way?
-        correct_rank = torch.arange(len(sorted_token_values))[
-            (sorted_token_values == answer_token).cpu()
-        ].item()
-        answer_ranks.append((answer_str_token, correct_rank))
+        token_probs = probs[:, index - 1]
+        sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True)
+        answer_token_ranks = sorted_token_positions.argsort(-1)[
+            range(n_answers), answer_tokens.cpu()
+        ].tolist()
+        answer_ranks.append(
+            [
+                (answer_str_token, answer_token_rank)
+                for answer_str_token, answer_token_rank in zip(
+                    answer_str_tokens, answer_token_ranks
+                )
+            ]
+        )
         if print_details:
             # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
             # rprint gives rich text printing
             rprint(
-                f"Performance on answer token:\n[b]Rank: {correct_rank: <8} Logit: {logits[index-1, answer_token].item():5.2f} Prob: {token_probs[answer_token].item():6.2%} Token: |{answer_str_token}|[/b]"
+                f"Performance on answer token{'s' if n_answers > 1 else ''}:\n"
+                + "\n".join(
+                    [
+                        f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]"
+                        for i in range(n_answers)
+                    ]
+                )
             )
             for i in range(top_k):
                 print(
-                    f"Top {i}th token. Logit: {logits[index-1, sorted_token_values[i]].item():5.2f} Prob: {sorted_token_probs[i].item():6.2%} Token: |{model.to_string(sorted_token_values[i])}|"
+                    f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|"
                 )
-    rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")
+
+    # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function
+    if not using_multiple_answers:
+        single_answer_ranks = [r[0] for r in answer_ranks]
+        rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}")
+    else:
+        rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")
 
 
 def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
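For reference, a minimal usage sketch of the patched test_prompt; the model choice and the prompt/answer strings below are illustrative only and not part of the patch:

    from transformer_lens import HookedTransformer
    from transformer_lens.utils import test_prompt

    model = HookedTransformer.from_pretrained("gpt2")  # illustrative model choice

    # Single answer string: behaviour is unchanged, every answer token is ranked.
    test_prompt("Why did the chicken cross the", "road", model)

    # List of answer strings: only the next-token completion is checked, and each
    # candidate's first token is ranked as a possible model answer.
    test_prompt("Why did the chicken cross the", ["road", "street"], model)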