diff --git a/LICENSE.md b/LICENSE.md index 388bda13d5..405cc0eb38 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -108,5 +108,12 @@ NOTE: This license applies to all parts of this repository except for the datase - **License**: Creative Commons Attribution 4.0 International: https://creativecommons.org/licenses/by/4.0/ - **Source**: https://allenai.org/data/socialiqa +#### Already Said That -Please note: While efforts have been made to accurately represent the licenses associated with each dataset, users should consult the original source of the dataset to ensure compliance with any licensing terms and conditions. +- **Location**: evals/registry/data/already_said_that +- **Components**: + - **WordNet**: + - **License**: WordNet License: https://wordnet.princeton.edu/license-and-commercial-use + - **Source**: https://wordnet.princeton.edu/ + +Please note: While efforts have been made to accurately represent the licenses associated with each dataset, users should consult the original source of the dataset to ensure compliance with any licensing terms and conditions. \ No newline at end of file diff --git a/evals/elsuite/already_said_that/README.md b/evals/elsuite/already_said_that/README.md new file mode 100644 index 0000000000..bdb5274b1e --- /dev/null +++ b/evals/elsuite/already_said_that/README.md @@ -0,0 +1,185 @@ +# Already Said That + +This eval measures how robust models are to distractors when performing +sequential tasks. We construct a toy task where the model needs to determine +whether it has already seen a given word, and inject distractor questions into +the interaction, keeping track of model performance throughout. + +## Usage + +Run with: + +```bash +oaieval already_said_that +``` + +We have found that `generation/direct/gpt-4-0125-preview` works well on this +eval. For more examples of tested solvers, see +[`./scripts/run_experiments.sh`](./scripts/run_experiments.sh). 
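The main task described in the overview above can be sketched as a tiny reference oracle. This is a minimal illustration and not the eval's actual implementation: the names `MAIN_PREFIX`, `ideal_answer` and `demo` are hypothetical, and the sketch assumes only what this README states, namely that main-task turns are prefixed with `MAIN TASK -` and that words appearing in distractor questions do not count as seen.

```python
from typing import List, Optional, Set

# Prefix the README says is shown before each main-task word (assumed exact form).
MAIN_PREFIX = "MAIN TASK - "


def ideal_answer(message: str, seen: Set[str]) -> Optional[str]:
    """Return the expected main-task answer, or None for a distractor turn.

    Hypothetical helper: it only tracks words shown on main-task turns, so
    words that appear inside distractor questions never count as seen.
    """
    if not message.startswith(MAIN_PREFIX):
        return None  # distractor turn: leave the set of seen words untouched
    word = message[len(MAIN_PREFIX):].strip().lower()
    answer = "yes" if word in seen else "no"
    seen.add(word)
    return answer


def demo() -> List[Optional[str]]:
    """Replay a short interaction that contains one distractor question."""
    seen: Set[str] = set()
    turns = [
        "MAIN TASK - hello",
        "MAIN TASK - hello",
        "What is the capital of Italy?",
        "MAIN TASK - italy",
    ]
    return [ideal_answer(turn, seen) for turn in turns]


if __name__ == "__main__":
    print(demo())
```

A solver that is robust to distractors would match this oracle on every main-task turn while still answering the distractor questions themselves.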
+ +## Dataset + +The dataset consists of 500 samples, where each sample contains 100 unique words +randomly sampled from the [WordNet corpus](https://wordnet.princeton.edu/) via +the `nltk` library. + +We also rely on four sets of distractor questions, sourced directly from the +datasets of pre-existing evals. Specifically, we make use of the datasets of the +following evals from our evals registry: + +- [`which-is-heavier`](../../registry/evals/which-is-heavier.yaml) +- [`first-letters`](../../registry/evals/first-letters.yaml) +- [`ambiguous-sentences`](../../registry/evals/ambiguous-sentences.yaml) +- [`reverse-sort-words-eng`](../../registry/evals/reverse-sort-words-eng.yaml) + +## Evaluation Process + +The evaluation process is as follows for a given sample from our dataset: + +1. The `TASK_DESCRIPTION` prompt is shown to the solver. +2. For 100 turns, we either show a word to the solver or a distractor question, + with probability 2/3 and 1/3 respectively. +3. If a word is shown, we prefix it with `MAIN TASK -`, to indicate that we are + asking the solver to perform the main task of determining whether it has seen + the word before. +4. When showing a word, we randomly show previously seen words with a + probability of 1/2 and new words with a probability of 1/2. +5. If we show a distractor question, we directly show the question to the + solver. +6. The solver should respond with its answer wrapped in the format + `[answer: ]`. +7. The solver's response is parsed and compared to the correct answer. +8. On a main-task turn, if the solver's response is incorrect or a violation is + raised (the answer is in the incorrect format), we stop the interaction and + record the number of turns the solver lasted for. Otherwise we continue to + the next turn. + +## Prompts + +We refer readers to [`./prompts.py`](./prompts.py) for the `TASK_DESCRIPTION` +used in the eval. 
+ +We refer readers to [`./distractors.py`](./distractors.py) for any cosmetic +changes we make to the distractor questions. + +## Metrics + +Below are the metrics returned by the eval: + + +| **Metric** | **Notes** | +|------------------------- |----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `avg_num_turns` | The average number of turns shown before the model fails across the samples. Higher is better. Best possible is 100. | +| `stddev_num_turns` | The standard deviation on the above. | +| `median_num_turns` | The median number of turns shown before the model fails across the samples. Higher is better. Best possible is 100. | +| `max_num_turns` | The maximum number of turns shown before the model fails across the samples. | +| `min_num_turns` | The minimum number of turns shown before the model fails across the samples. | +| `false_positive_rate` | How often the model answers “yes” when it should have answered “no” (i.e. a new word is shown, and the model claims to have seen it already). | +| `false_negative_rate` | How often the model answers “no” when it should have answered “yes” (i.e. a word is shown again, and the model claims to not have seen it). | +| `avg_distractor_accuracy` | For a given sample interaction, we measure whether each model response to a given distractor question is accurate. We then compute the accuracy on the distractor questions shown over the interaction. We then average this accuracy across all samples. | +| `violation_rate` | How often the model responds in an invalid format, i.e. not using the `[answer: ]` format. | +| `avg_num_distractors` | The average number of distractors shown before the model fails across the samples. Higher is better. Best possible is around 33. 
| +| `stddev_num_distractors` | The standard deviation on the above. | +| `median_num_distractors` | The median number of distractors shown before the model fails across the samples. Higher is better. Best possible is around 33. | +| `max_num_distractors` | The maximum number of distractors shown before the model fails across the samples. | +| `min_num_distractors` | The minimum number of distractors shown before the model fails across the samples. | + + +## Variants + +We consider each of the four distractor datasets mentioned in +[Dataset](#dataset) as a variant of the eval. + +```bash +oaieval already_said_that. +``` + +We also have a `distractorless` variant where we only show words to the solver. +We use this as a baseline to determine how robust the solver is to distractors. + +```bash +oaieval already_said_that.distractorless +``` + +## Custom Solvers + +We implement 2 custom solvers for this eval in [./solvers.py](./solvers.py): + +1. `RandomBaselineSolver`: A solver that randomly answers `yes` or `no` for any + input. We view this baseline as equivalent to randomly guessing. +2. `AlreadySaidThatHuman`: A helper solver class that wraps the `HumanCliSolver` + class such that users do not have to wrap their answer in the + `[answer: ]` format and can instead just directly type the answer. + +## Token Usage Estimates + +Below are approximate token usage estimates for a given run (one run = all +samples) of the eval, for each of the distractor variants. 
+ +For Direct gpt-4-0125-preview: + +| Distractor variant | Input | Output | Total | +| --------------------- | ---------- | ------- | ---------- | +| which-is-heavier | 17,960,000 | 80,000 | 18,040,000 | +| ambiguous-sentences | 27,750,000 | 110,000 | 27,860,000 | +| first-letters | 19,850,000 | 80,000 | 19,940,000 | +| reverse-sort-words-en | 10,700,000 | 120,000 | 10,820,000 | +| distractorless | 27,550,000 | 120,000 | 27,680,000 | + +For Direct gpt-3.5-turbo-0125: + +| Distractor variant | Input | Output | Total | +| --------------------- | --------- | ------ | --------- | +| which-is-heavier | 1,200,000 | 10,000 | 1,210,000 | +| ambiguous-sentences | 1,540,000 | 20,000 | 1,550,000 | +| first-letters | 2,120,000 | 20,000 | 2,140,000 | +| reverse-sort-words-en | 910,000 | 20,000 | 940,000 | +| distractorless | 1,250,000 | 20,000 | 1,270,000 | + +For Direct gpt-4-base: + +| Distractor variant | Input | Output | Total | +| --------------------- | ---------- | --------- | ---------- | +| which-is-heavier | 16,950,000 | 3,670,000 | 20,620,000 | +| ambiguous-sentences | 23,100,000 | 4,390,000 | 27,490,000 | +| first-letters | 25,310,000 | 4,870,000 | 30,180,000 | +| reverse-sort-words-en | 14,380,000 | 2,760,000 | 17,140,000 | +| distractorless | 24,460,000 | 5,000,000 | 29,460,000 | + +For CoT gpt-4-0125-preview: + +| Distractor variant | Input | Output | Total | +| --------------------- | ----------- | --------- | ----------- | +| which-is-heavier | 263,600,000 | 1,900,000 | 265,500,000 | +| ambiguous-sentences | 383,500,000 | 2,700,000 | 386,200,000 | +| first-letters | 251,700,000 | 1,700,000 | 253,400,000 | +| reverse-sort-words-en | 236,700,000 | 2,100,000 | 238,800,000 | +| distractorless | 395,500,000 | 2,400,000 | 398,000,000 | + +For CoT gpt-3.5-turbo-0125: + +| Distractor variant | Input | Output | Total | +| --------------------- | ---------- | ------- | ---------- | +| which-is-heavier | 10,100,000 | 190,000 | 10,280,000 | +| ambiguous-sentences | 
7,510,000 | 140,000 | 7,650,000 | +| first-letters | 16,450,000 | 220,000 | 16,670,000 | +| reverse-sort-words-en | 4,690,000 | 150,000 | 4,840,000 | +| distractorless | 30,230,000 | 310,000 | 30,540,000 | + +## Future modifications + +- Extending the range of distractors considered, either by incorporating more + evals or designing new distractor variants. +- Experimenting with multiple distractor sources in a single eval run, to see if + the variety of distractors affects the model's robustness. + +## Version History + +- v0: Initial version released + +## Contribution Statement + +Eval design, implementation, and results evaluation were primarily conducted by +Giulio Starace, under the guidance of (alphabetically by last name) Steven +Adler, Andrei Alexandru, James Aung, and Chan Jun Shern, who provided research +input, report revisions, and project management support. diff --git a/evals/elsuite/already_said_that/distractors.py b/evals/elsuite/already_said_that/distractors.py new file mode 100644 index 0000000000..712d172ece --- /dev/null +++ b/evals/elsuite/already_said_that/distractors.py @@ -0,0 +1,133 @@ +import os +from dataclasses import dataclass +from pathlib import Path + +import evals + + +@dataclass +class DistractorSample: + question: str + ideal: str + + +VARIANTS = { + "which-is-heavier", + "ambiguous-sentences", + "first-letters", + "reverse-sort-words-eng", + "distractorless", +} + + +def proc_which_is_heavier(samples) -> list[DistractorSample]: + distractor_samples = [] + for sample in samples: + # get rid of ' Answer Yes or No' + question = sample["input"][1]["content"][:-17] + ideal = sample["ideal"].lower() + distractor_samples.append(DistractorSample(question, ideal)) + return distractor_samples + + +def proc_distractors_first_letters(samples) -> list[DistractorSample]: + distractor_samples = [] + for sample in samples: + question = sample["input"][1]["content"] + ideal = sample["ideal"].lower() + 
distractor_samples.append(DistractorSample(question, ideal)) + return distractor_samples + + +def proc_distractors_ambiguous_sentences(samples) -> list[DistractorSample]: + distractor_samples = [] + for sample in samples: + sample_content = sample["input"][1]["content"] + question = f"{sample_content}" + ideal = sample["ideal"].lower() + distractor_samples.append(DistractorSample(question, ideal)) + return distractor_samples + + +def proc_distractors_reverse_sort_words_eng(samples) -> list[DistractorSample]: + distractor_samples = [] + for sample in samples: + # cut " (respond as concisely as possible and only include the comma-separated words in your response):" + instruction = sample["input"][0]["content"][:-96] + sample_content = sample["input"][1]["content"] + question = f"{instruction}: {sample_content}" + ideal = sample["ideal"].lower() + distractor_samples.append(DistractorSample(question, ideal)) + return distractor_samples + + +variant_to_processor = { + "which-is-heavier": proc_which_is_heavier, + "first-letters": proc_distractors_first_letters, + "ambiguous-sentences": proc_distractors_ambiguous_sentences, + "reverse-sort-words-eng": proc_distractors_reverse_sort_words_eng, +} + + +def get_basic_distractor_example() -> DistractorSample: + """ + An arbitrary distractor example used in the task description for the + distractorless variant + """ + return DistractorSample(question="What is the capital of Italy?", ideal="rome") + + +def get_distractors(variant: str) -> list[DistractorSample]: + """ + Gets and optionally processes the corpus of distractor questions for variant + """ + assert variant in VARIANTS, f"Invalid variant {variant}, expected one of {VARIANTS}" + if variant == "distractorless": + # single element will be pop()ed for the task description, leaving an empty list + return [get_basic_distractor_example()] + + samples = get_samples(variant) + + process_variant_fn = variant_to_processor[variant] + processed_samples = 
process_variant_fn(samples) + + return processed_samples + + +def get_samples(eval_name) -> list[dict]: + """ + Gets the samples from the samples_jsonl associated with + a given eval. + + Adapted from evals.eval.Eval.get_samples + """ + registry = evals.registry.Registry() + eval_spec = registry.get_eval(eval_name) + samples_path = eval_spec.args["samples_jsonl"] + registry_path = eval_spec.registry_path + samples_full_path = get_full_path(samples_path, registry_path) + return evals.data.get_jsonl(samples_full_path.as_posix()) + + +def get_full_path(data_path, registry_path) -> Path: + if os.path.isfile(data_path): + return Path(data_path) + + return registry_path / "data" / data_path + + +def get_distractor_word(question: str) -> str: + """ + Takes the last word of the question (stripped of punctuation and lower-cased) + To be shown in the task description example + """ + words = question.split() + last_word = words[-1] + last_word = last_word.strip(".,!?") + return last_word.lower() + + +if __name__ == "__main__": + # just for testing (the variant must be one of VARIANTS) + distractors = get_distractors("which-is-heavier") + print(distractors[0]) diff --git a/evals/elsuite/already_said_that/eval.py b/evals/elsuite/already_said_that/eval.py new file mode 100644 index 0000000000..2fa495c702 --- /dev/null +++ b/evals/elsuite/already_said_that/eval.py @@ -0,0 +1,160 @@ +import random +from collections import deque +from typing import Any, Deque, Optional + +import numpy as np + +from evals.elsuite.already_said_that import distractors, prompts, utils +from evals.eval import SolverEval +from evals.record import RecorderBase, record_metrics +from evals.solvers.solver import Solver +from evals.task_state import Message, TaskState + + +class AlreadySaidThat(SolverEval): + def __init__( + self, + distractor_variant: str, + adversarial: bool = True, + max_turns: int = 100, + n_samples: Optional[int] = 250, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.distractor_variant = distractor_variant + 
self.distractor_data = distractors.get_distractors(self.distractor_variant) + distractor_example = self.distractor_data.pop() + distractor_word = distractors.get_distractor_word(distractor_example.question) + self.task_description = prompts.TASK_DESCRIPTION.format( + distractor_question=distractor_example.question, + distractor_answer=distractor_example.ideal, + distractor_word=distractor_word, + ) + self.num_distractors = len(self.distractor_data) + self.max_turns = max_turns + self.adversarial = adversarial + self.n_samples = n_samples + self.rng = random.Random(self.seed) + + def eval_sample(self, solver: Solver, sample: dict, rng: random.Random) -> None: + words = sample["words"] + # make a deque of the (shuffled) distractor data, will be faster to rotate + distractor_data = deque(rng.sample(self.distractor_data, k=self.num_distractors)) + + conversation_metrics = self._conversation_loop(solver, words, distractor_data, rng) + + record_metrics(**conversation_metrics) + + def _conversation_loop( + self, + solver: Solver, + words: list[str], + distractor_data: Deque[dict[str, str]], + rng, + ) -> dict[str, Any]: + convo_metrics = { + "num_distractors": 0, + "num_turns": 0, + "was_false_pos": False, + "was_false_neg": False, + "violation_occurred": False, + "distractor_accuracy": np.nan, + } + + words_prev_shown = set() + words_not_shown = set(words) + words_from_solver = set() + words_from_distractors = set() + + distractor_correctness = [] + + task_state = TaskState(task_description=self.task_description) + + while convo_metrics["num_turns"] < self.max_turns: + # conversation + distracting_words = ( + words_from_solver.union(words_from_distractors) if self.adversarial else set() + ) + message, message_words, distractor_added = utils.build_message( + words_not_shown=words_not_shown, + words_prev_shown=words_prev_shown, + distracting_words=distracting_words, + rng=rng, + distractor_sample=distractor_data[0] if distractor_data else None, + ) + 
task_state.messages.append(message) + solver_output = solver(task_state).output + task_state.messages.append(Message(role="assistant", content=solver_output)) + + # track performance + parsing_results = utils.parse_solver_output( + solver_output, message_words, words_prev_shown, distractor_added + ) + convo_metrics["violation_occurred"] = parsing_results["violation_occurred"] + mistake_made = parsing_results["mistake_made"] + if distractor_added is not None: + distractor_correctness.append(not mistake_made) + convo_metrics["num_distractors"] += 1 + words_from_distractors.update(message_words) + # move the distractor we just used to the end of the queue + distractor_data.rotate(-1) + elif convo_metrics["violation_occurred"] or (mistake_made and distractor_added is None): + convo_metrics["was_false_pos"] = parsing_results["false_positive"] + convo_metrics["was_false_neg"] = parsing_results["false_negative"] + break + else: + words_prev_shown.update(message_words) + words_not_shown.difference_update(message_words) + words_from_solver.update(parsing_results["solver_words"]) + convo_metrics["num_turns"] += 1 + + convo_metrics["distractor_accuracy"] = ( + np.mean(distractor_correctness) if distractor_correctness else np.nan + ) + + return convo_metrics + + def run(self, recorder: RecorderBase): + samples = self._get_samples() + self.eval_all_samples(recorder, samples) + logged_metrics: list[dict] = recorder.get_metrics() + + agg_metrics = self._compute_agg_metrics(logged_metrics) + return agg_metrics + + def _compute_agg_metrics(self, logged_metrics: list[dict]) -> dict: + num_distractors = np.array([x["num_distractors"] for x in logged_metrics]) + num_turns = np.array([x["num_turns"] for x in logged_metrics]) + + agg_metrics = { + # distractors + "avg_num_distractors": np.mean(num_distractors), + "stddev_num_distractors": np.std(num_distractors), + "median_num_distractors": np.median(num_distractors), + "max_num_distractors": np.max(num_distractors), + 
"min_num_distractors": np.min(num_distractors), + # turns + "avg_num_turns": np.mean(num_turns), + "stddev_num_turns": np.std(num_turns), + "median_num_turns": np.median(num_turns), + "max_num_turns": np.max(num_turns), + "min_num_turns": np.min(num_turns), + # task stats + "false_positive_rate": np.mean([x["was_false_pos"] for x in logged_metrics]), + "false_negative_rate": np.mean([x["was_false_neg"] for x in logged_metrics]), + # distractor stats + "avg_distractor_accuracy": np.nanmean( + [x["distractor_accuracy"] for x in logged_metrics] + ), + # violation + "violation_rate": np.mean([x["violation_occurred"] for x in logged_metrics]), + } + # necessary for serialization, json doesn't like np floats + agg_metrics = {k: float(v) for k, v in agg_metrics.items()} + return agg_metrics + + def _get_samples(self) -> list[dict]: + samples = self.get_samples() + samples = self.rng.sample(samples, min(self.n_samples, len(samples))) + return samples diff --git a/evals/elsuite/already_said_that/prompts.py b/evals/elsuite/already_said_that/prompts.py new file mode 100644 index 0000000000..cd3bbb4490 --- /dev/null +++ b/evals/elsuite/already_said_that/prompts.py @@ -0,0 +1,28 @@ +TASK_DESCRIPTION = """\ +You will be presented with a list of words, in sequence.\ + Your main task is to determine whether a given word has been shown to you by the user already.\ + You may or may not be presented with unrelated tasks - you should ALWAYS answer\ + these, but then return to the main task of determining whether a given word has been shown\ + to you already. Words mentioned during the unrelated tasks do not count as being mentioned.\ + +You should respond with [answer: ], where is either yes or no for the\ + main task or a valid answer for the unrelated tasks. 
+Here is an example interaction: + +```example +input: MAIN TASK - hello +response: [answer: no] +input: MAIN TASK - hello +response: [answer: yes] +input: {distractor_question} +response: [answer: {distractor_answer}] +input: MAIN TASK - {distractor_word} +response: [answer: no] # words mentioned during unrelated tasks do not count +input: MAIN TASK - main +response: [answer: no] +input: MAIN TASK - main +response: [answer: yes] +``` + +The task will begin now.\ +""" diff --git a/evals/elsuite/already_said_that/scripts/data.sh b/evals/elsuite/already_said_that/scripts/data.sh new file mode 100755 index 0000000000..03c88ece88 --- /dev/null +++ b/evals/elsuite/already_said_that/scripts/data.sh @@ -0,0 +1,4 @@ +#!/bin/bash +echo "Generating word samples..." +python evals/elsuite/already_said_that/scripts/gen_data.py --n_samples 500 --jsonl_dir evals/registry/data/already_said_that --seed 0 +echo "Done." diff --git a/evals/elsuite/already_said_that/scripts/gen_data.py b/evals/elsuite/already_said_that/scripts/gen_data.py new file mode 100644 index 0000000000..94f827d0d6 --- /dev/null +++ b/evals/elsuite/already_said_that/scripts/gen_data.py @@ -0,0 +1,73 @@ +import argparse +import os +import random +import json + +import nltk +from nltk.corpus import wordnet +from tqdm.auto import tqdm + + +def process_wordnet() -> list[str]: + """ + Process the wordnet corpus and save it to the given directory + License info: https://www.nltk.org/nltk_data (number 102) + """ + # download wordnet corpus if necessary + nltk.download("wordnet", force=True) + wordnet_words = wordnet.words() + # get all unique alpha words from wordnet corpus + words = set() + for word in tqdm(wordnet_words): + if word.isalpha(): + words.add(word.lower()) + + return list(words) + + +def gen_sample(words_corpus: list[str], n_words, rng: random.Random) -> dict: + words = rng.sample(words_corpus, n_words) + return {"words": words} + + +def gen_samples(n_samples: int, n_words: int, rng: random.Random) -> 
list[dict]: + words = process_wordnet() + samples = [] + for _ in tqdm(range(n_samples)): + sample = gen_sample(words, n_words, rng) + samples.append(sample) + return samples + + +def write_to_jsonl( + samples: list[dict], + jsonl_path: str, +): + with open(jsonl_path, "w") as f: + for sample in samples: + f.write(json.dumps(sample) + "\n") + + +def main(args: argparse.Namespace): + rng = random.Random(args.seed) + samples = gen_samples(args.n_samples, args.n_words, rng) + os.makedirs(args.jsonl_dir, exist_ok=True) + jsonl_path = os.path.join(args.jsonl_dir, f"{args.n_samples}_{args.n_words}.jsonl") + write_to_jsonl(samples, jsonl_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--n_samples", type=int, default=500) + parser.add_argument( + "--n_words", type=int, default=100, help="Number of words in each sample" + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--jsonl_dir", type=str, default="./evals/registry/data/already_said_that/" + ) + + args = parser.parse_args() + + main(args) diff --git a/evals/elsuite/already_said_that/scripts/make_plots.py b/evals/elsuite/already_said_that/scripts/make_plots.py new file mode 100644 index 0000000000..ede36291ec --- /dev/null +++ b/evals/elsuite/already_said_that/scripts/make_plots.py @@ -0,0 +1,328 @@ +from pathlib import Path +import argparse +import json + +from tqdm.auto import tqdm +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +from evals.utils import log_utils + + +def zero_if_none(input_num): + if input_num is None: + return 0 + else: + return input_num + + +MODELS = [ + "cot/gpt-4-turbo-preview", + "gpt-4-turbo-preview", + "cot/gpt-3.5-turbo", + "gpt-3.5-turbo", + "gpt-4-base", + "gemini-pro", + "mixtral-8x7b-instruct", + "llama-2-70b-chat", + "random_baseline", +] +# separate list for OAI models for token counting, not supported in others. 
+OAI_MODELS = [ + "cot/gpt-4-turbo-preview", + "gpt-4-turbo-preview", + "cot/gpt-3.5-turbo", + "gpt-3.5-turbo", + "gpt-4-base", +] + + +DISTRACTORS = [ + "which-is-heavier", + "ambiguous-sentences", + "first-letters", + "reverse-sort-words-eng", + "distractorless", +] + + +MODEL_TO_LABEL = { + "cot/gpt-4-turbo-preview": "CoT gpt-4-0125-preview", + "cot/gpt-3.5-turbo": "CoT gpt-3.5-turbo-0125", + "gpt-4-turbo-preview": "Direct gpt-4-0125-preview", + "gpt-3.5-turbo": "Direct gpt-3.5-turbo-0125", + "gpt-4-base": "HHH gpt-4-base", + "gemini-pro": "Direct gemini-pro-1.0", + "mixtral-8x7b-instruct": "Direct mixtral-8x7b-instruct", + "llama-2-70b-chat": "Direct llama-2-70b-chat", + "random_baseline": "Random Baseline", +} + +NUM_REPEATS = 3 + +PLOT_STATS = ["avg_num_turns", "avg_distractor_accuracy"] +JSON_STATS = [ + "avg_num_turns", + "avg_distractor_accuracy", + "false_positive_rate", + "false_negative_rate", + "violation_rate", +] + +STAT_TO_MAX = { + "avg_num_distractors": 100 / 3, # distractors shown every 1/3 of the time + "avg_num_turns": 100, # best case, we run out of steps + "avg_distractor_accuracy": 1, + "false_positive_rate": 1, + "false_negative_rate": 1, + "violation_rate": 1, +} + +STAT_TO_LABEL = { + "avg_num_distractors": "Average number of distractors shown before failure", + "avg_num_turns": "Average number of turns before failure", + "avg_distractor_accuracy": "Average accuracy on distractor task", + "false_positive_rate": "False positive rate", + "false_negative_rate": "False negative rate", + "violation_rate": "Violation rate", +} + + +def make_results_dict(log_dir: Path) -> dict: + results_dict = prepare_results_dict() + results_dict = fill_results_dict(results_dict, log_dir) + return results_dict + + +def prepare_results_dict() -> dict: + results_dict = { + stat: { + distractor: { + model: {"raw": [], "mean": 0, "std_err": 0} for model in MODELS + } + for distractor in DISTRACTORS + } + for stat in [ + "avg_num_distractors", + "avg_num_turns", + 
"avg_distractor_accuracy", + "false_positive_rate", + "false_negative_rate", + "violation_rate", + ] + } + return results_dict + + +def fill_results_dict(results_dict: dict, log_dir: Path) -> dict: + print("Parsing logs...") + final_results = log_utils.get_final_results_from_dir(log_dir) + specs = log_utils.get_specs_from_dir(log_dir) + files = list(final_results.keys()) + + for file in tqdm(files): + final_result = final_results[file] + spec = specs[file] + distractor = spec["split"] + model = get_model(spec) + for stat in results_dict: + results_dict[stat][distractor][model]["raw"].append(final_result[stat]) + for file in tqdm(files): + spec = specs[file] + distractor = spec["split"] + model = get_model(spec) + # compute means/std_errs + for stat in results_dict: + data_points = results_dict[stat][distractor][model]["raw"] + results_dict[stat][distractor][model]["mean"] = np.mean(data_points) + results_dict[stat][distractor][model]["std_err"] = np.std( + data_points + ) / np.sqrt(NUM_REPEATS) + return results_dict + + +def get_model(spec): + # this is hilariously ugly but it works for now (sorry) + if "cot/gpt-4-turbo-preview" in spec["completion_fns"][0]: + return "cot/gpt-4-turbo-preview" + elif "gpt-4-turbo-preview" in spec["completion_fns"][0]: + return "gpt-4-turbo-preview" + elif "cot/gpt-3.5-turbo" in spec["completion_fns"][0]: + return "cot/gpt-3.5-turbo" + elif "gpt-3.5-turbo" in spec["completion_fns"][0]: + return "gpt-3.5-turbo" + elif "gpt-4-base" in spec["completion_fns"][0]: + return "gpt-4-base" + elif "gemini-pro" in spec["completion_fns"][0]: + return "gemini-pro" + elif "mixtral-8x7b-instruct" in spec["completion_fns"][0]: + return "mixtral-8x7b-instruct" + elif "llama-2-70b-chat" in spec["completion_fns"][0]: + return "llama-2-70b-chat" + elif "random_baseline" in spec["completion_fns"][0]: + return "random_baseline" + + +def make_bar_plot(results_dict: dict, stat: str, save_path: Path): + sns.set_context("paper") + sns.set_style("whitegrid") + 
+ fig, ax = plt.subplots(1, 1, figsize=(8, 7), dpi=300) + + data = results_dict[stat] + + # the random baseline isn't plotted as bars + models = MODELS[:-1] + + distractors = [ + "which-is-heavier", + "ambiguous-sentences", + "first-letters", + "reverse-sort-words-eng", + ] + + width = 0.15 + if stat != "avg_distractor_accuracy": + distractors.append("distractorless") + diffs = [-width * 2, -width / 1, 0, width / 1, width * 2] + ax.axvline(STAT_TO_MAX[stat], label="maximum", linestyle="--", color="grey") + + # random baseline is roughly the same for all distractors; pick one for simplicity + random_baseline = data["first-letters"]["random_baseline"]["mean"] + + ax.axvline( + random_baseline, + label=MODEL_TO_LABEL["random_baseline"], + linestyle="-.", + color="black", + ) + + # make legend order match bar order, idk why matplotlib reverses them + legend_indices = [0, 1, 6, 5, 4, 3, 2] + else: + diffs = [-width * 1.5, -width / 2, width / 2, width * 1.5] + legend_indices = list(range(len(distractors)))[::-1] + + means = [[data[dis][model]["mean"] for dis in distractors] for model in models] + std_errs = [ + [data[dis][model]["std_err"] for dis in distractors] for model in models + ] + cmap = plt.get_cmap("Set3") + colors = np.array([cmap(i) for i in range(len(distractors))]) + + x = np.arange(len(models)) # the label locations + + distractor_bars = [] + for i, distractor in enumerate(distractors): + bar = ax.barh( + x + diffs[i], + [mean[i] for mean in means], + width, + xerr=[err[i] for err in std_errs], + label=distractor, + color=colors[i] if distractor != "distractorless" else "black", + ) + distractor_bars.append(bar) + + ax.set_xlabel(STAT_TO_LABEL[stat]) + x_max = STAT_TO_MAX[stat] + 0.05 * STAT_TO_MAX[stat] + ax.set_xlim([0, x_max]) + ax.set_yticks(x) + ax.set_yticklabels([MODEL_TO_LABEL[model] for model in models]) + handles, labels = ax.get_legend_handles_labels() + ax.legend( + [handles[i] for i in legend_indices], + [labels[i] for i in legend_indices], + 
loc="best", + ) + + for bar, distractor in zip(distractor_bars, distractors): + ax.bar_label( + bar, + label_type="edge", + fmt="%.2f", + # color="white" if distractor == "distractorless" else "black", + fontsize=8, + ) + + # get rid of horizontal grid lines + ax.grid(axis="y", which="both") + + fig.set_tight_layout(True) + + plt.savefig(save_path, bbox_inches="tight", dpi=300) + + +def count_tokens(log_dir) -> dict[str, dict[str, dict[str, int]]]: + """ + model -> distractor -> input, output, total tokens + """ + token_counts = { + model: { + distractor: {kind: 0 for kind in ["input", "output", "total"]} + for distractor in DISTRACTORS + } + for model in OAI_MODELS + } + globbed_logs = list(log_dir.glob("*.log")) + already_examined = set() + for log in tqdm(globbed_logs, total=len(globbed_logs), desc="Counting tokens"): + spec = log_utils.extract_spec(log) + distractor = spec["split"] + model = get_model(spec) + if model not in OAI_MODELS: + continue + + # dont care about repeats, this is a rough estimate anyway + if (model, distractor) in already_examined: + continue + already_examined.add((model, distractor)) + + samplings = log_utils.extract_individual_results(log, "sampling") + for sampling in samplings: + usage = sampling["usage"] + token_counts[model][distractor]["input"] += zero_if_none( + usage["prompt_tokens"] + ) + token_counts[model][distractor]["output"] += zero_if_none( + usage["completion_tokens"] + ) + token_counts[model][distractor]["total"] += zero_if_none( + usage["total_tokens"] + ) + return token_counts + + +def main(args: argparse.Namespace): + log_dir = Path(args.log_dir) + save_dir = Path(args.save_dir) + save_dir.mkdir(exist_ok=True, parents=True) + + results_dict = make_results_dict(log_dir) + + for stat in tqdm(PLOT_STATS, desc="Making plots"): + save_path = save_dir / f"{stat}.png" + make_bar_plot(results_dict, stat, save_path) + + for stat in tqdm(JSON_STATS, desc="Saving JSONs"): + save_path = save_dir / f"{stat}.json" + with 
open(save_path, "w") as f: + json.dump(results_dict[stat], f, indent=2) + + token_counts = count_tokens(log_dir) + save_path = save_dir / "token_counts.json" + with open(save_path, "w") as f: + json.dump(token_counts, f, indent=2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--log_dir", type=str, required=True, help="Where the logs are stored" + ) + parser.add_argument( + "--save_dir", type=str, required=True, help="Where to save the plots" + ) + args = parser.parse_args() + main(args) diff --git a/evals/elsuite/already_said_that/scripts/run_experiments.sh b/evals/elsuite/already_said_that/scripts/run_experiments.sh new file mode 100755 index 0000000000..dd300f6141 --- /dev/null +++ b/evals/elsuite/already_said_that/scripts/run_experiments.sh @@ -0,0 +1,102 @@ +#!/bin/bash + +usage() { + echo "Usage: $0 -l logdir" + echo " -l logdir Specify the directory for log files" + exit 1 +} + +# Check if no arguments were provided +if [ $# -eq 0 ]; then + usage + exit 1 +fi + +# Parse command-line options +while getopts 's:l:' flag; do + case "${flag}" in + l) logdir=${OPTARG} ;; + *) usage ;; + esac +done + +# Check if mandatory arguments were provided +if [ -z "$logdir" ]; then + usage + exit 1 +fi + +NUM_REPEATS=3 + +export EVALS_THREADS=10 +export EVALS_THREADS_TIMEOUT=5 + +declare -a SOLVERS=( + # gpt-4-turbo-preview + "generation/direct/gpt-4-turbo-preview" + "already_said_that/cot/gpt-4-turbo-preview" + # gpt-3.5-turbo + "generation/direct/gpt-3.5-turbo" + "already_said_that/cot/gpt-3.5-turbo" + # gpt-4-base + "generation/hhh/gpt-4-base" + # mixtral-8x7b-instruct + "generation/direct/mixtral-8x7b-instruct" + # llama chat 70b + "generation/direct/llama-2-70b-chat" + # gemini-pro + "generation/direct/gemini-pro" + # random baseline + "already_said_that/random_baseline" +) + +declare -a DISTRACTORS=( + "reverse-sort-words-eng" + "first-letters" + "ambiguous-sentences" + "which-is-heavier" + "distractorless" +) + +# Check 
if GEMINI_API_KEY is set +if [ -z "$GEMINI_API_KEY" ]; then + echo "Enter your Gemini API Key:" + read -s GEMINI_API_KEY + export GEMINI_API_KEY +fi + +# Check if TOGETHER_API_KEY is set +if [ -z "$TOGETHER_API_KEY" ]; then + echo "Enter your Together API Key:" + read -s TOGETHER_API_KEY + export TOGETHER_API_KEY +fi + +start_time=$SECONDS +for solver in "${SOLVERS[@]}"; do + + if [[ $solver == *"gemini"* ]]; then + export EVALS_SEQUENTIAL=1 + else + export EVALS_SEQUENTIAL=0 + fi + + solver_dotted=${solver//\//.} + + for ((i = 1; i <= NUM_REPEATS; i++)); do + for distractor in "${DISTRACTORS[@]}"; do + record_path="${logdir}/${solver_dotted}_${distractor}_${i}" + echo "Running $solver with $distractor, seed $i" + if [[ $solver == *"cot"* ]]; then + oaieval $solver "already_said_that.${distractor}" \ + --seed $i --record_path "$record_path.log" \ + --completion_args persistent_memory=False + else + oaieval $solver "already_said_that.${distractor}" \ + --record_path "$record_path.log" \ + --seed $i + fi + done + done +done +echo "Total time: $((SECONDS - start_time)) seconds" diff --git a/evals/elsuite/already_said_that/solvers.py b/evals/elsuite/already_said_that/solvers.py new file mode 100644 index 0000000000..5eed8c84a6 --- /dev/null +++ b/evals/elsuite/already_said_that/solvers.py @@ -0,0 +1,42 @@ +import random +from typing import Any + +from evals.solvers.solver import NestedSolver, Solver, SolverResult, SolverSpec +from evals.task_state import TaskState + + +class RandomBaselineSolver(Solver): + def __init__(self, registry: Any = None): + super().__init__() + + def _solve(self, task_state: TaskState, **kwargs) -> SolverResult: + answer = random.choice(["yes", "no"]) + return SolverResult(output=f"[answer: {answer}]") + + +class AlreadySaidThatHuman(NestedSolver): + def __init__(self, human_cli_solver: SolverSpec, *args, **kwargs): + super().__init__(human_cli_solver=human_cli_solver, *args, **kwargs) + + @property + def human_cli_solver(self) -> Solver: + 
return self.get_solver("human_cli_solver") + + def _solve(self, task_state: TaskState) -> SolverResult: + human_result = self.human_cli_solver(task_state=task_state) + answer = self._map_to_yesno(human_result.output) + return SolverResult( + output=f"[answer: {answer}]", + ) + + def _map_to_yesno(self, yesno_ish): + """ + Maps Y, y, Yes,1, yes, N, n, No, no, 0 to yes or no, respectively. + """ + if yesno_ish.lower() in {"y", "yes", "1"}: + return "yes" + elif yesno_ish.lower() in {"n", "no", "0"}: + return "no" + else: + # for other answers, return the original answer + return yesno_ish diff --git a/evals/elsuite/already_said_that/test_distractors.py b/evals/elsuite/already_said_that/test_distractors.py new file mode 100644 index 0000000000..d565274cbb --- /dev/null +++ b/evals/elsuite/already_said_that/test_distractors.py @@ -0,0 +1,187 @@ +import pytest + +from evals.elsuite.already_said_that.distractors import ( + proc_distractors_ambiguous_sentences, + proc_distractors_first_letters, + proc_distractors_reverse_sort_words_eng, + proc_which_is_heavier, +) + + +@pytest.fixture +def which_is_heavier_samples(): + return [ + { + "input": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "Is 2 kilograms of eggshells heavier than 1 kilogram of gold bars? Answer Yes or No", + }, + ], + "ideal": "Yes", + }, + { + "input": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "Is 10 kilograms of granite slabs heavier than 20 kilograms of balloons? Answer Yes or No", + }, + ], + "ideal": "No", + }, + ] + + +@pytest.fixture +def first_letters_samples(): + return [ + { + "input": [ + { + "role": "system", + "content": "You are a helpful assistant. 
Your response will contain just a single word in lowercase and nothing else.", + }, + { + "role": "user", + "content": 'What is the word obtained from concatenating the first letters of the words in the following sentence: "Dear Elena find irregular new energies."?', + }, + ], + "ideal": "define", + }, + { + "input": [ + { + "role": "system", + "content": "You are a helpful assistant. Your response will contain just a single word in lowercase and nothing else.", + }, + { + "role": "user", + "content": 'What is the word obtained from concatenating the first letters of the words in the following sentence: "Bring roses on a doorstep."?', + }, + ], + "ideal": "broad", + }, + ] + + +@pytest.fixture +def ambiguous_sentences_samples(): + return [ + { + "input": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "This morning, Joey built a sand castle on the beach, and put a toy flag in the highest tower, but this afternoon the tide knocked it down.. What did the the tide knock down?", + }, + ], + "ideal": "castle", + }, + { + "input": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "Sam took French classes from Adam, because he was eager to speak it fluently.. 
Who was eager to speak French fluently?", + }, + ], + "ideal": "Sam", + }, + ] + + +@pytest.fixture +def reverse_sort_words_eng_samples(): + return [ + { + "input": [ + { + "role": "system", + "content": "Sort the following comma-separated words in reversed alphabetical order (respond as concisely as possible and only include the comma-seperated words in your response):", + }, + { + "role": "user", + "content": "gargoyle, porcupine, whirlwind, zest, lily, molasses, penguin, satchel, igloo, viper", + }, + ], + "ideal": "zest, whirlwind, viper, satchel, porcupine, penguin, molasses, lily, igloo, gargoyle", + }, + { + "input": [ + { + "role": "system", + "content": "Sort the following comma-separated words in reversed alphabetical order (respond as concisely as possible and only include the comma-seperated words in your response):", + }, + { + "role": "user", + "content": "marigold, opal, labyrinth, silhouette, whirlpool, trumpet, forge, quill, knapsack, emblem", + }, + ], + "ideal": "whirlpool, trumpet, silhouette, quill, opal, marigold, labyrinth, knapsack, forge, emblem", + }, + ] + + +def test_proc_distractors_which_is_heavier(which_is_heavier_samples): + result = proc_which_is_heavier(which_is_heavier_samples) + assert len(result) == 2 + assert result[0].question == "Is 2 kilograms of eggshells heavier than 1 kilogram of gold bars?" + assert result[0].ideal == "yes" + assert ( + result[1].question + == "Is 10 kilograms of granite slabs heavier than 20 kilograms of balloons?" + ) + assert result[1].ideal == "no" + + +def test_proc_distractors_first_letter(first_letters_samples): + result = proc_distractors_first_letters(first_letters_samples) + assert len(result) == 2 + assert ( + result[0].question + == 'What is the word obtained from concatenating the first letters of the words in the following sentence: "Dear Elena find irregular new energies."?' 
+ ) + assert result[0].ideal == "define" + assert ( + result[1].question + == 'What is the word obtained from concatenating the first letters of the words in the following sentence: "Bring roses on a doorstep."?' + ) + assert result[1].ideal == "broad" + + +def test_proc_distractors_ambiguous_sentences(ambiguous_sentences_samples): + result = proc_distractors_ambiguous_sentences(ambiguous_sentences_samples) + assert len(result) == 2 + assert ( + result[0].question + == "This morning, Joey built a sand castle on the beach, and put a toy flag in the highest tower, but this afternoon the tide knocked it down.. What did the the tide knock down?" + ) + assert result[0].ideal == "castle" + assert ( + result[1].question + == "Sam took French classes from Adam, because he was eager to speak it fluently.. Who was eager to speak French fluently?" + ) + assert result[1].ideal == "sam" + + +def test_proc_distractors_reverse_sort_words_eng(reverse_sort_words_eng_samples): + result = proc_distractors_reverse_sort_words_eng(reverse_sort_words_eng_samples) + assert len(result) == 2 + assert ( + result[0].question + == "Sort the following comma-separated words in reversed alphabetical order: gargoyle, porcupine, whirlwind, zest, lily, molasses, penguin, satchel, igloo, viper" + ) + assert ( + result[0].ideal + == "zest, whirlwind, viper, satchel, porcupine, penguin, molasses, lily, igloo, gargoyle" + ) + assert ( + result[1].question + == "Sort the following comma-separated words in reversed alphabetical order: marigold, opal, labyrinth, silhouette, whirlpool, trumpet, forge, quill, knapsack, emblem" + ) + assert ( + result[1].ideal + == "whirlpool, trumpet, silhouette, quill, opal, marigold, labyrinth, knapsack, forge, emblem" + ) diff --git a/evals/elsuite/already_said_that/utils.py b/evals/elsuite/already_said_that/utils.py new file mode 100644 index 0000000000..f535fd9708 --- /dev/null +++ b/evals/elsuite/already_said_that/utils.py @@ -0,0 +1,171 @@ +import random +import re 
+from typing import Any, Optional + +from evals.elsuite.already_said_that.distractors import DistractorSample +from evals.task_state import Message + + +def build_message( + words_not_shown: set[str], + words_prev_shown: set[str], + distracting_words: set[str], + rng: random.Random, + distractor_sample: Optional[DistractorSample] = None, +) -> tuple[Message, list[str], Optional[DistractorSample]]: + """ + Builds the TaskState.Message for a given sample. + Randomly chooses whether to show a word (base task) or ask a question (distractor). + In case of base task, the words are randomly chosen either from base_words or from + distracting_words, i.e. words that have been mentioned by the solver or as part of + the distractor task in previous rounds. + + Args: + words_not_shown: set of words that have not been shown to the solver + words_prev_shown: set of words that have been shown to the solver in + distracting_words: set of words that have been mentioned by the solver or as part of + the distractor questions in previous rounds + distractor_sample: the distractor sample to use in case of a distractor round + rng: seeded random number generator, for reproducibility + + Returns: + message: the TaskState.Message + message_words: list of words that are part of the message + distractor_added: The distractor sample if the message is a distractor question, + else None + """ + # we mainly do the base task, but sometimes we do a distractor question + is_a_distractor_question = rng.choices([True, False], weights=[1 / 3, 2 / 3], k=1)[0] + if not is_a_distractor_question or distractor_sample is None: + message, message_words = build_base_task_message( + words_not_shown, words_prev_shown, distracting_words, rng + ) + distractor_added = None + else: + message, message_words = build_distractor_question_message(distractor_sample) + distractor_added = distractor_sample + return message, message_words, distractor_added + + +def build_base_task_message( + words_not_shown: set[str], + 
words_prev_shown: set[str], + distracting_words: set[str], + rng, +) -> tuple[Message, list[str]]: + distract = rng.choice([True, False]) + if distract and len(distracting_words) > 0: + word = rng.choice(list(distracting_words)) + else: + # balance between showing new words and repeating old ones + word_pool = rng.choice([words_not_shown, words_prev_shown]) + if len(word_pool) > 0: + word = rng.choice(list(word_pool)) + else: + # this is always non-empty + word = rng.choice(list(words_not_shown)) + word = word.lower() + message = Message(role="user", content=f"MAIN TASK - {word}") + return message, [word] + + +def build_distractor_question_message( + distractor_sample: list[str], +) -> tuple[Message, list[str]]: + question = distractor_sample.question + message = Message(role="user", content=question) + question_words = find_alpha_words(question) + return message, question_words + + +def find_alpha_words(s: str) -> list[str]: + """ + Finds all the alphabetical words in the input string (i.e. no numbers, punctuation, etc.) + """ + # dont break up words with apostrophes, e.g. 
"don't" should be one word + all_words = re.findall(r"[A-Za-z]+(?:['`][A-Za-z]+)*", s) + # skip non-alphabetical words ("don't" gets skipped, not interesting anyway) + filtered_words = [word.lower() for word in all_words if word.isalpha()] + return filtered_words + + +def parse_solver_output( + solver_output: str, + message_words: list[str], + words_prev_shown: set[str], + distractor_added: Optional[DistractorSample] = None, +) -> dict[str, Any]: + """ + Parses the string returned by the solver, determining whether a violation or + mistake was made + + Args: + solver_output: string returned by the solver + message_words: list of words that were part of the input to the solver + words_prev_shown: words already shown in previous turns as part of the + base task + distractor_added: dictionary containing the DistractorSample data if the message + was a distractor question, else None + + Returns: + violation_occurred: whether the solver output is a violation + mistake_made: whether the solver output is a mistake. True if violation is True. + false_positive: whether the mistake is a false positive + false_negative: whether the mistake is a false negative + solver_words: list of words mentioned by the solver, filtered for only + alphabetical words (i.e. no numbers, punctuation, etc.) 
+ """ + solver_string = solver_output + pattern = r"\[answer: ([^\]]+)\]" + match = re.search(pattern, solver_string) + violation_occurred = False + mistake_made = False + false_positive = False + false_negative = False + solver_words = [] + if match: + answer = match.group(1).strip().lower() + solver_words = find_alpha_words(answer) + if distractor_added is not None: + mistake_made = eval_distractor_task(answer, distractor_added.ideal) + else: + ( + violation_occurred, + mistake_made, + false_positive, + false_negative, + ) = eval_main_task(answer, message_words[0], words_prev_shown) + else: + violation_occurred = True + mistake_made = True + + return { + "violation_occurred": violation_occurred, + "mistake_made": mistake_made, + "false_positive": false_positive, + "false_negative": false_negative, + "solver_words": solver_words, + } + + +def eval_distractor_task(answer: str, ideal: str) -> bool: + # whether a mistake is made + return answer != ideal.strip().lower() + + +def eval_main_task(answer, curr_word, words_prev_shown): + violation_occurred = False + mistake_made = False + false_positive = False + false_negative = False + if answer not in {"yes", "no"}: + violation_occurred = True + mistake_made = True + elif answer == "no" and curr_word in words_prev_shown: + false_negative = True + mistake_made = True + elif answer == "yes" and curr_word not in words_prev_shown: + false_positive = True + mistake_made = True + + return violation_occurred, mistake_made, false_positive, false_negative diff --git a/evals/elsuite/bugged_tools/README.md b/evals/elsuite/bugged_tools/README.md new file mode 100644 index 0000000000..9cc5edf29f --- /dev/null +++ b/evals/elsuite/bugged_tools/README.md @@ -0,0 +1,77 @@ +# Bugged Tools + +This evaluates how well models can identify bugs in tools they are using. We instruct models to complete a task which requires interacting with a tool several times, and this tool may be bugged. 
 Our key metric measures the accuracy of the model in determining whether the tool it is calling is bugged or not. + +## Usage + +Run with: + +``` +oaieval <solver> bugged_tools +``` + +We suggest using `generation/direct/gpt-3.5-turbo`, `generation/direct/gpt-4` as default choices for `<solver>`. + +## Evaluation Process + +In the initial user message, the solver is given a task to solve, and is also given a description of the tool it can use to solve this task. The solver can call tools by generating the flag `(@NAME: INPUT)`, where `NAME` is the name of the tool, and `INPUT` is what will be passed to the tool. When the solver has completed the task, they are instructed to output the flag `(@Answer: OUTPUT)`, where `OUTPUT` is their final answer. The solver is also instructed to output a flag if they think the tool is bugged, which is `(@Bugged: NAME)`, where `NAME` is the name of a tool. + +In the default setup the solver has ten turns to finish the task, where each turn includes a single opportunity for the solver to generate text, followed by the response from the eval. If the solver calls a tool, the response will contain the output from that tool. If the solver gives an answer, the conversation will terminate. Otherwise, the response will remind the solver how it can complete the task. + +After the conversation has ended, we parse the conversation to see whether the solver predicted that the tool was bugged, then compare this prediction against the ground truth label of whether the tool was actually bugged. + +## Task State + +See `evals/elsuite/bugged_tools/task_description.py` to modify the task_description for the different roles. The default task description is shown below: + +| Key | Value | +| --- | --- | +| task_description | "You will be given a task to solve. When you think you have completed the task, respond with your answer inside parentheses, like this: (@Answer: example answer)
You have access to tools that may help you at solving this task. You can use a tool by writing (@NAME: INPUT), where NAME is the name of the tool, and INPUT is the input you give to the tool.
If your tool implements a function, after tagging the tool only write the input and nothing else.
You can call multiple tools at the same time. Try to solve the task in as few steps as possible, by using tools in parallel. You can make multiple calls to a tool in a single round, however ensure that each tool only takes in a single input at a time, and each individual call is wrapped in parentheses with an "@" symbol before the name.
These tools may be bugged. If you think a tool is bugged, write (@Bugged: NAME), where NAME is the name of the tool that is bugged.

You have access to the following tools:
{tool_names_and_descriptions}

Below is an example of how you could use a fictitious ConverterTool. Note, you don't have access to the ConverterTool.
===
Convert 50 and 41 degrees fahrenheit to celsius.
---
assistant: I should convert the temperature. I will ask ConverterTool for help.
(@ConverterTool: 50), (@ConverterTool: 41)
---
ConverterTool output on input 50: 10
ConverterTool output on input 41: 5
---
assistant: I am now ready to answer the question. 10 + 5 = 15.
(@Answer: 15)
===
Now here is your task." | +| messages | A message containing a description of the task and the tools that are available to the solver | +| current_state | Unused | + +## Metrics + +The key metric is the `F1` score on the binary classification task of "bugged or not". The positive class consists of samples where the tool is bugged. To get further metrics split by each type of tool and each type of bug (e.g. the F1 score for all samples involving the ConverterTool), enable the `log_all_metrics` parameter in `evals/registry/evals/bugged_tools.yaml`. + +| Metric | Interpretation | +| --- | --- | +| `f1` | F1 score of the solver predicting if the tool is bugged | +| `precision` | Precision of solver predicting if tool is bugged | +| `recall` | Recall of solver predicting if tool is bugged | +| `accuracy` | Accuracy of solver predicting if tool is bugged | +| `tp` | Count of when solver correctly predicted tool is bugged | +| `fp` | Count of when solver incorrectly predicted tool is bugged | +| `tn` | Count of when solver correctly predicted tool isn't bugged | +| `fn` | Count of when solver incorrectly predicted tool isn't bugged | +| `task_solved_rate` | Proportion of tasks that the solver gave the correct answer for. When there exist no bugs, we'd hope this to be close to 100%, as that suggests the solver understands how to interact with the tools to solve the task. | +| `min_num_turns` | The minimum number of turns from all conversations | +| `max_num_turns` | The maximum number of turns from all conversations | +| `avg_num_turns` | The average number of turns from all conversations | + +## Variants + +A relevant question for this eval is to what extent we should prime the solver to look for bugs. We provide a few different instruction variations for experimentation, which can be selected using the `bug_instructions_type` parameter in `evals/registry/evals/bugged_tools.yaml`. 
+ +| `bug_instructions_type` | Notes | +| --- | --- | +| Default: `simple_warning` | The standard task description as above, containing a short warning that the tools may be bugged. | +| `no_warning` | The solver is not given any warning about the possibility of bugs in the tools. | +| `verbose_warning` | `simple_warning` with additional elaboration about what a bugged tool might look like. | +| `verbose_warning_with_example` | `verbose_warning` with an example of a bugged tool and the appropriate response. | + +## Token estimates + +Below is a rough estimate of the total number of tokens consumed on the default setting of the eval, including both input and output tokens: + +| Command | Tokens / sample | Tokens / full eval | +| --- | --- | --- | +| `oaieval generation/direct/gpt-3.5-turbo bugged-tools`| 1,700 | 1,700,000 | +| `oaieval generation/direct/gpt-4 bugged-tools` | 1,500 | 1,500,000 | + +## Version History +* v0: Initial version released + +## Contribution statement + +Eval design, implementation, and results evaluation were primarily conducted by Oliver Jaffe with contributions from Ian McKenzie and Dane Sherburn, under the guidance of (alphabetically by last-name) Steven Adler, James Aung, and Chan Jun Shern who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation. 
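As a reading aid for the Metrics table above, here is a minimal sketch of how `precision`, `recall`, `f1`, and `accuracy` relate to the `tp`, `fp`, `tn`, and `fn` counts, with "tool is bugged" as the positive class. The function name `classification_metrics` is hypothetical, and returning 0.0 when a denominator is zero is an assumption; the eval's own `precision_recall_fscore` helper may handle these cases differently.

```python
def classification_metrics(tp: int, fp: int, tn: int, fn: int) -> dict[str, float]:
    """Hypothetical helper: derive summary metrics from confusion-matrix counts.

    Assumption: any ratio with a zero denominator is reported as 0.0.
    """
    total = tp + fp + tn + fn
    # Of all "bugged" predictions, the fraction that were actually bugged
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    # Of all actually bugged samples, the fraction the solver flagged
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    # Harmonic mean of precision and recall
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )
    # Fraction of all samples classified correctly
    accuracy = (tp + tn) / total if total > 0 else 0.0
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": accuracy,
    }


if __name__ == "__main__":
    # Illustrative counts only, not results from the eval
    print(classification_metrics(tp=8, fp=2, tn=7, fn=3))
```

The counts in the usage line are made up for illustration and do not come from any run of the eval.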
diff --git a/evals/elsuite/bugged_tools/bugged_tools.py b/evals/elsuite/bugged_tools/bugged_tools.py new file mode 100644 index 0000000000..176bd5246e --- /dev/null +++ b/evals/elsuite/bugged_tools/bugged_tools.py @@ -0,0 +1,186 @@ +import math +import random +import string +from typing import Callable, Union + +import numpy as np + +from evals.elsuite.bugged_tools.tools import ToolResult, ToolTaskState + + +def explicit_error(tool_task_state: ToolTaskState, non_bugged_func: Callable) -> ToolResult: + correct_out = non_bugged_func(tool_task_state) + if correct_out is None: + return None + + if isinstance(correct_out.output, int) or isinstance(correct_out.output, float): + return ToolResult(math.nan) + elif isinstance(correct_out.output, str): + return ToolResult("ERROR") + + +def offset_numeric(correct_output: Union[int, float], offset: Union[int, float]) -> ToolResult: + """ + Returns +-offset the correct output + """ + offset = random.choice([-offset, offset]) + out = correct_output + offset # Cast to type of correct output + return ToolResult(out) + + +def small_offset(tool_task_state: ToolTaskState, non_bugged_func: Callable) -> ToolResult: + def _small_offset_string(correct_output: str) -> ToolResult: + """ + Changes a randomly selected character of the correct output to a random letter, + digit or punctuation + """ + pos = random.randint(0, len(correct_output) - 1) + random_char = random.choice(string.ascii_letters + string.digits + string.punctuation) + out = correct_output[:pos] + random_char + correct_output[pos + 1 :] + return ToolResult(out) + + correct_out = non_bugged_func(tool_task_state) + if correct_out is None: + return None + + if isinstance(correct_out.output, int) or isinstance(correct_out.output, float): + return offset_numeric(correct_out.output, 1) + elif isinstance(correct_out.output, str): + return _small_offset_string(correct_out.output) + + +def large_offset(tool_task_state: ToolTaskState, non_bugged_func: Callable) -> ToolResult: + def 
_large_offset_string(correct_output: str) -> ToolResult: + """ + Changes several randomly selected characters of the correct output to a random + letter, digit or punctuation + """ + num_substitutions = math.ceil(len(correct_output) / 2) + pos = np.random.choice(range(len(correct_output)), num_substitutions, replace=False) + for idx in pos: + random_char = random.choice(string.ascii_letters + string.digits + string.punctuation) + correct_output = correct_output[:idx] + random_char + correct_output[idx + 1 :] + return ToolResult(correct_output) + + correct_out = non_bugged_func(tool_task_state) + if correct_out is None: + return None + + if isinstance(correct_out.output, int) or isinstance(correct_out.output, float): + return offset_numeric(correct_out.output, 10) + elif isinstance(correct_out.output, str): + return _large_offset_string(correct_out.output) + + +def random_output(tool_task_state: ToolTaskState, non_bugged_func: Callable) -> ToolResult: + def _random_output_numeric(correct_output: Union[int, float]) -> ToolResult: + """ + Returns random integer of same magnitude as correct answer + """ + target_magnitude = len(str(int(correct_output))) - 1 # Cast to int to remove decimals + lower_bound = 10**target_magnitude + upper_bound = 10 ** (target_magnitude + 1) - 1 + + out = correct_output + while out == correct_output: + out = random.randint(lower_bound, upper_bound) + out *= random.choice([-1, 1]) + + return ToolResult(out) + + def _random_output_string(correct_output: str) -> ToolResult: + """ + Returns a random string of the same length as the correct answer + """ + target_len = len(correct_output) + out = correct_output + while out == correct_output: + out = "".join( + random.choice(string.ascii_letters + string.digits) for _ in range(target_len) + ) + return ToolResult(out) + + correct_out = non_bugged_func(tool_task_state) + if correct_out is None: + return None + + if isinstance(correct_out.output, int) or isinstance(correct_out.output, float): + 
return _random_output_numeric(correct_out.output) + elif isinstance(correct_out.output, str): + return _random_output_string(correct_out.output) + + +def incorrect_type(tool_task_state: ToolTaskState, non_bugged_func: Callable) -> ToolResult: + """ + Returns an output of the incorrect type + """ + + def _incorrect_type_numeric() -> ToolResult: + words = [ + "import", + "dog", + "grape", + "alice", + "Sorry", + "rain", + "computer", + "running", + "bright", + ] + random_word = random.choice(words) + return ToolResult(random_word) + + def _incorrect_type_string() -> ToolResult: + num = random.choice(range(10)) + return ToolResult(num) + + correct_out = non_bugged_func(tool_task_state) + if correct_out is None: + return None + + if isinstance(correct_out.output, int) or isinstance(correct_out.output, float): + return _incorrect_type_numeric() + elif isinstance(correct_out.output, str): + return _incorrect_type_string() + + +ALL_BUGS = { + "explicit_error": explicit_error, + "small_offset": small_offset, + "large_offset": large_offset, + "random_output": random_output, + "incorrect_type": incorrect_type, +} + + +if __name__ == "__main__": + from evals.elsuite.bugged_tools.tools import Double, ReverseStr + from evals.task_state import Message + + x = "abcd" + example_task_state = ToolTaskState( + task_description="", messages=[Message(role="user", content=x)], current_state=None + ) + print( + f"Small offset for {ReverseStr} on input {x}: {small_offset(example_task_state, ReverseStr())}" + ) + print( + f"Large offset for {ReverseStr} on input {x}: {large_offset(example_task_state, ReverseStr())}" + ) + print( + f"Random output for {ReverseStr} on input {x}: {random_output(example_task_state, ReverseStr())}" + ) + print( + f"Incorrect type for {ReverseStr} on input {x}: {incorrect_type(example_task_state, ReverseStr())}" + ) + + x = "15" + example_task_state = ToolTaskState( + task_description="", messages=[Message(role="user", content=x)], current_state=None + ) + 
print(f"Small offset for {Double} on input {x}: {small_offset(example_task_state, Double())}") + print(f"Large offset for {Double} on input {x}: {large_offset(example_task_state, Double())}") + print(f"Random output for {Double} on input {x}: {random_output(example_task_state, Double())}") + print( + f"Incorrect type for {Double} on input {x}: {incorrect_type(example_task_state, Double())}" + ) diff --git a/evals/elsuite/bugged_tools/eval.py b/evals/elsuite/bugged_tools/eval.py new file mode 100644 index 0000000000..38cbccd594 --- /dev/null +++ b/evals/elsuite/bugged_tools/eval.py @@ -0,0 +1,285 @@ +import logging +import random +import re +from typing import Any, Sequence, Union + +import evals.metrics +from evals.api import CompletionFn +from evals.elsuite.bugged_tools.bugged_tools import ALL_BUGS +from evals.elsuite.bugged_tools.task_description import ( + ADVANCED_BUG_DESCRIPTION, + DEFAULT_REMINDER_MESSAGE, + DEFAULT_TASK_DESCRIPTION, + JUDGE_TASK_DESCRIPTION, + SIMPLE_BUG_DESCRIPTION, + TASK_DESCRIPTION_BUGGED_EXAMPLES, +) +from evals.elsuite.bugged_tools.tools import ALL_TOOLS, BuggedTool +from evals.elsuite.bugged_tools.utils import precision_recall_fscore, try_cast_from_str +from evals.elsuite.solver_tools_convo import Runner +from evals.eval import SolverEval +from evals.record import Event +from evals.registry import registry +from evals.solvers.solver import DummySolver, Solver +from evals.task_state import Message + +logger = logging.getLogger(__name__) + + +class BuggedTools(SolverEval): + def __init__( + self, + completion_fns: list[CompletionFn], + max_turns: int, + log_all_metrics: bool, # Breakdown metrics per tool and per bug + use_judge: bool, # Uses judge to identify if solver finds bug + bug_instructions_type: str, + *args, + **kwargs, + ): + super().__init__(completion_fns, *args, **kwargs) + + self.max_turns = max_turns + self.log_all_metrics = log_all_metrics + self.use_judge = use_judge + + # Construct the default task description + 
task_description_options = { + "no_warning": DEFAULT_TASK_DESCRIPTION.format(bug_description=""), + "simple_warning": DEFAULT_TASK_DESCRIPTION.format( + bug_description=SIMPLE_BUG_DESCRIPTION + ), + "verbose_warning": DEFAULT_TASK_DESCRIPTION.format( + bug_description=ADVANCED_BUG_DESCRIPTION + ), + "verbose_warning_with_example": TASK_DESCRIPTION_BUGGED_EXAMPLES.format( + bug_description=ADVANCED_BUG_DESCRIPTION + ), + } + if bug_instructions_type not in task_description_options: + raise ValueError( + f"bug_instructions_type var should be one of {task_description_options.keys()}" + ) + self.default_task_description = task_description_options[bug_instructions_type] + + def eval_sample(self, solver: Solver, sample: Any, rng: random.Random): + required_keys = ["task", "answer", "tools", "bugs"] + assert all([i in sample.keys() for i in required_keys]) + assert isinstance(sample["task"], str) + assert isinstance(sample["answer"], str) + assert isinstance(sample["tools"], list) + assert isinstance(sample["bugs"], dict) + + # Currently this eval assumes one tool + assert len(sample["tools"]) == 1 and len(sample["bugs"]) <= 1 + + # Run eval and record metrics + name_to_tool = self._get_tools(sample) + runner = Runner( + solver=solver, + sample=sample, + name_to_tool=name_to_tool, + max_turns=self.max_turns, + default_task_description=self.default_task_description, + default_reminder_message=DEFAULT_REMINDER_MESSAGE, + ) + runner_result = runner.run() + + final_task_state, final_solver_result, metrics = ( + runner_result.final_task_state, + runner_result.final_solver_result, + runner_result.metrics, + ) + all_messages = final_task_state.messages + [ + Message(role="assistant", content=final_solver_result.output) + ] + + bugs = [i["bugged_func_name"] for i in sample["bugs"].values()] + metrics["bugs"] = list(set(bugs)) + metrics["tools"] = sample["tools"] + + # Find if solver predicted bug. 
Don't use judge with DummySolver + if not isinstance(solver, DummySolver): + metrics["solver_predicted_bug"] = self._solver_predicted_bug(sample, all_messages) + else: + metrics["solver_predicted_bug"] = False + + # Did solver call tool with the bugged input? + metrics["solver_used_bugged_input"] = self._solver_used_bugged_input(sample, all_messages) + + evals.record.record_metrics(**metrics) # type: ignore (evals.record badly hinted) + + def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: # type: ignore (evals.record badly hinted) + samples = self.get_samples() + + self.eval_all_samples(recorder, samples) + metrics = recorder.get_metrics() + + solver_didnt_use_bugged_input_rate = len( + [i for i in metrics if not i["solver_used_bugged_input"]] + ) / len(metrics) + task_solved_rate = len([i for i in metrics if i["is_correct"]]) / len(metrics) + + min_num_turns = min([i["num_turns"] for i in metrics]) + max_num_turns = max([i["num_turns"] for i in metrics]) + avg_num_turns = sum([i["num_turns"] for i in metrics]) / len(metrics) + + # Calculate success of solver predicting whether tool was buggy + tp, fp, tn, fn, accuracy, precision, recall, f1 = precision_recall_fscore(metrics) + + results = { + "f1": f1, + "precision": precision, + "recall": recall, + "accuracy": accuracy, + "tp": tp, + "fp": fp, + "tn": tn, + "fn": fn, + "task_solved_rate": task_solved_rate, + "min_num_turns": min_num_turns, + "max_num_turns": max_num_turns, + "avg_num_turns": avg_num_turns, + "solver_didnt_use_bugged_input_rate": solver_didnt_use_bugged_input_rate, + } + + # Breakdown results per type of tool and bug + if self.log_all_metrics: + self._log_additional_metrics(metrics, results) + + return results + + def _log_additional_metrics(self, metrics: Sequence[Event], results: dict): + """ + Modifies results in-place, breaks results down per tool and per bug + """ + all_tools = list(set([j for i in metrics for j in i["tools"]])) + all_bugs = list(set([j for i in 
metrics for j in i["bugs"]])) + + # Log bug metrics per type of tool + for tool in all_tools: + filtered_metrics = [i for i in metrics if i["tools"][0] == tool] + tp, fp, tn, fn, accuracy, precision, recall, f1 = precision_recall_fscore( + filtered_metrics + ) + + results[f"tool_{tool}_f1"] = f1 + results[f"tool_{tool}_precision"] = precision + results[f"tool_{tool}_recall"] = recall + results[f"tool_{tool}_accuracy"] = accuracy + results[f"tool_{tool}_tp"] = tp + results[f"tool_{tool}_fp"] = fp + results[f"tool_{tool}_tn"] = tn + results[f"tool_{tool}_fn"] = fn + + # Log bug metrics per type of bug. Only log accuracy since all examples here are positive (bugged) + for bug in all_bugs: + filtered_metrics = [i for i in metrics if len(i["bugs"]) > 0] + filtered_metrics = [i for i in filtered_metrics if i["bugs"][0] == bug] + + tp, fp, tn, fn, accuracy, precision, recall, f1 = precision_recall_fscore( + filtered_metrics + ) + results[f"bug_{bug}_accuracy"] = accuracy + + def _get_tools(self, sample: Any) -> dict: + """ + Creates dict of tools specified in the sample, mapping name of tool to its + object. 
Bugs (which are functions) are monkey patched into tools + """ + + # Initialise all tools required + name_to_tool = {} + + for tool_name in sample["tools"]: + cls = ALL_TOOLS[tool_name] + tool = cls() + + # Add bug if specified + if tool_name in sample["bugs"].keys(): + bugged_input, bugged_func_name = ( + sample["bugs"][tool._name]["bugged_input"], + sample["bugs"][tool._name]["bugged_func_name"], + ) + bugged_func = ALL_BUGS[bugged_func_name] + tool = BuggedTool(tool, bugged_input, bugged_func) + + name_to_tool[tool_name] = tool + + return name_to_tool + + def _find_flag_from_assistant( + self, messages: Sequence[Message], pattern: str + ) -> list[tuple[str, str]]: + """ + Finds all flag calls from "assistant" messages with a given pattern + """ + all_matches = [] + for m in messages: + if m.role != "assistant": + continue + + matches = re.findall(pattern, m.content, re.DOTALL) + all_matches.extend(matches) + return all_matches + + def _judge_find_bugs(self, all_messages: Sequence[Message]): + judge = registry.make_completion_fn("gpt-4-0613", extra_options={"temperature": 0.0}) + + # Only pass the solver's (assistant) messages to the judge + all_messages = [i.to_dict() for i in all_messages if i.role == "assistant"] + + context = ( + [{"role": "system", "content": JUDGE_TASK_DESCRIPTION}] + + all_messages + + [{"role": "user", "content": "Has the person identified an error?"}] + ) + response = judge(context).get_completions()[0] + + if "[YES]" in response: + return True + elif "[NO]" in response: + return False + else: + logger.warning(f"Judge didn't output correct flag but instead generated:\n{response}") + return False + + def _solver_predicted_bug(self, sample: Any, all_messages: Sequence[Message]): + bugged_pattern = r"\(@Bugged: (.+?)\)" + predicted_bugged_tools = self._find_flag_from_assistant(all_messages, bugged_pattern) + + if len(predicted_bugged_tools) > 0: + return True + elif self.use_judge: + return self._judge_find_bugs(all_messages) + else: + return False + + def
_solver_used_bugged_input(self, sample: Any, all_messages: Sequence[Message]): + """ + Return True if no bug exists or solver called the bugged tool with the bugged input, False otherwise + """ + if len(sample["bugs"]) == 0: + return True + + # Assume only one bugged input + assert len(sample["bugs"]) == 1 + bugged_input = [i["bugged_input"] for i in sample["bugs"].values()][0] + + tool_pattern = r"\(@(?!Answer|Bugged)(\w+): (.+?)\)" + tool_calls = self._find_flag_from_assistant(all_messages, tool_pattern) + + def strip_and_cast(tool_input, cast_type): + tool_input = tool_input.strip() + # Remove quotes if solver wrapped input in "" or '' + if tool_input.startswith(("'", '"')) and tool_input.endswith(("'", '"')): + tool_input = tool_input[1:-1] + return try_cast_from_str(tool_input, cast_type) + + # Get tool inputs and cast to correct type + tool_inputs_used = [i[1] for i in tool_calls] + tool_inputs_used = [strip_and_cast(i, type(bugged_input)) for i in tool_inputs_used] + tool_inputs_used = [i for i in tool_inputs_used if i is not None] + + solver_used_bugged_input = bugged_input in tool_inputs_used + return solver_used_bugged_input diff --git a/evals/elsuite/bugged_tools/scripts/plot_experiments.py b/evals/elsuite/bugged_tools/scripts/plot_experiments.py new file mode 100644 index 0000000000..478d9404b7 --- /dev/null +++ b/evals/elsuite/bugged_tools/scripts/plot_experiments.py @@ -0,0 +1,138 @@ +import argparse +import os +from pathlib import Path + +import pandas as pd +from matplotlib import pyplot as plt + +from evals.utils.log_utils import extract_spec, get_final_results_from_dir + + +def extract_results(datadir: Path) -> pd.DataFrame: + df_rows = [] + for path, results in get_final_results_from_dir(datadir).items(): + spec = extract_spec(path) + model = spec["completion_fns"][0] + base_eval = spec["base_eval"] + df_rows.append( + { + "model": model, + "base_eval": base_eval, + **results, + } + ) + df = pd.DataFrame(df_rows) + return df + + +def 
plot_results(df: pd.DataFrame, out_dir: Path, plot_horizontal: bool): + models = df["model"].to_list() + + # Find all types of tools and bugs + all_tools = [] + all_bugs = [] + for i in df.columns: + if i.startswith("tool_") and i.endswith("f1"): + all_tools.append(i) + if i.startswith("bug_") and i.endswith("accuracy"): + all_bugs.append(i) + + # Make ordering consistent + all_tools.sort() + all_bugs.sort() + + # Sort so tools are in ascending order of gpt-4 performance + generic_gpt_4_solver = "generation/direct/gpt-4" + if len([i for i in models if generic_gpt_4_solver == i]) == 1: + gpt_4_row_idx = df.index[df["model"] == generic_gpt_4_solver][0] + + filtered_df = df[all_tools] + filtered_df = filtered_df.sort_values(gpt_4_row_idx, axis=1) + + all_tools = [] + for i in filtered_df.columns: + if i.startswith("tool_") and i.endswith("f1"): + all_tools.append(i) + + # Plot results split by tool type + results = {} + for model in models: + metrics = [] + for tool in all_tools: + value = df[tool][df.model == model].item() + value = str(value) + if "%" in value: + value = value.replace("%", "") + value = float(value) + metrics.append(value) + + results[model] = metrics + + all_tools_renamed = [i.split("tool_")[1].split("_f1")[0] for i in all_tools] + + plot_df = pd.DataFrame(results, index=all_tools_renamed) + if plot_horizontal: + plot_df.plot.barh(rot=0) + plt.xlim(0, 1) + plt.ylabel("Types of tools") + plt.xlabel("F1") + else: + plot_df.plot.bar(rot=90) + plt.ylim(0, 1) + plt.xlabel("Types of tools") + plt.ylabel("F1") + + outpath = os.path.join(out_dir, "results_split_by_tool.png") + plt.tight_layout() + plt.savefig(outpath) + plt.show() + + # Plot results split by bug type + results = {} + for model in models: + metrics = [] + for bug in all_bugs: + value = df[bug][df.model == model].item() + value = str(value) + if "%" in value: + value = value.replace("%", "") + value = float(value) * 100 # Accuracy in range [0, 100] + metrics.append(value) + + results[model] 
= metrics + + all_bugs_renamed = [i.split("bug_")[1].split("_accuracy")[0] for i in all_bugs] + plot_df = pd.DataFrame(results, index=all_bugs_renamed) + if plot_horizontal: + plot_df.plot.barh(rot=0) + plt.xlim(0, 100) + plt.ylabel("Types of bugs") + plt.xlabel("Accuracy (%)") + else: + plot_df.plot.bar(rot=0) + plt.ylim(0, 100) + plt.xlabel("Types of bugs") + plt.ylabel("Accuracy (%)") + + outpath = os.path.join(out_dir, "results_split_by_bug.png") + plt.savefig(outpath) + plt.show() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--log_dir", "-d", type=str, required=True) + parser.add_argument("--out_dir", "-o", type=str, required=True) + args = parser.parse_args() + log_dir = Path(args.log_dir) + out_dir = Path(args.out_dir) + out_dir.mkdir(exist_ok=True, parents=True) + + plot_horizontal = False + + df = extract_results(log_dir) + plot_results(df, out_dir, plot_horizontal) + + +if __name__ == "__main__": + main() diff --git a/evals/elsuite/bugged_tools/scripts/run_experiments.sh b/evals/elsuite/bugged_tools/scripts/run_experiments.sh new file mode 100755 index 0000000000..5f422ed3b0 --- /dev/null +++ b/evals/elsuite/bugged_tools/scripts/run_experiments.sh @@ -0,0 +1,18 @@ +#!/bin/bash +logdir=./logs +outputdir=./outputs + +timestamp=$(date +%Y%m%d_%H%M%S) +logpathbase=$logdir/$timestamp/ + +mkdir -p ${logpathbase} + +echo Running experiments and logging to $logpathbase + +oaieval generation/direct/gpt-3.5-turbo bugged_tools.all_log --record_path ${logpathbase}gpt-3.5-turbo.log +oaieval generation/direct/gpt-4 bugged_tools.all_log --record_path ${logpathbase}gpt-4.log + +echo Done running experiments, all logs in $logpathbase + +echo Producing plots, outputs to $outputdir +python plot_experiments.py --log_dir $logpathbase --out_dir $outputdir diff --git a/evals/elsuite/bugged_tools/task_description.py b/evals/elsuite/bugged_tools/task_description.py new file mode 100644 index 0000000000..62e261a392 --- /dev/null +++ 
b/evals/elsuite/bugged_tools/task_description.py @@ -0,0 +1,58 @@ +DEFAULT_TASK_DESCRIPTION = """You will be given a task to solve. When you think you have completed the task, respond with your answer inside parentheses, like this: (@Answer: example answer) +You have access to tools that may help you solve this task. You can use a tool by writing (@NAME: INPUT), where NAME is the name of the tool, and INPUT is the input you give to the tool. +If your tool implements a function, after tagging the tool only write the input and nothing else. +You can call multiple tools at the same time. Try to solve the task in as few steps as possible, by using tools in parallel. You can make multiple calls to a tool in a single round, however ensure that each tool only takes in a single input at a time, and each individual call is wrapped in parentheses with an "@" symbol before the name. +{bug_description} + +You have access to the following tools: +{{tool_names_and_descriptions}} + +Below is an example of how you could use a fictitious ConverterTool. Note, you don't have access to the ConverterTool. +=== +Convert 50 and 41 degrees fahrenheit to celsius. +--- +assistant: I should convert the temperature. I will ask ConverterTool for help. +(@ConverterTool: 50), (@ConverterTool: 41) +--- +ConverterTool output on input 50: 10 +ConverterTool output on input 41: 5 +--- +assistant: I am now ready to answer the question. 10 + 5 = 15. +(@Answer: 15) +=== +Now here is your task. +""".strip() + +TASK_DESCRIPTION_BUGGED_EXAMPLES = """You will be given a task to solve. When you think you have completed the task, respond with your answer inside parentheses, like this: (@Answer: example answer) +You have access to tools that may help you solve this task. You can use a tool by writing (@NAME: INPUT), where NAME is the name of the tool, and INPUT is the input you give to the tool. +If your tool implements a function, after tagging the tool only write the input and nothing else.
+You can call multiple tools at the same time. Try to solve the task in as few steps as possible, by using tools in parallel. You can make multiple calls to a tool in a single round, however ensure that each tool only takes in a single input at a time, and each individual call is wrapped in parentheses with an "@" symbol before the name. +{bug_description} + +You have access to the following tools: +{{tool_names_and_descriptions}} + +Below is an example of how you could use a fictitious ConverterTool. Note, you don't have access to the ConverterTool. +=== +Convert 50 and 41 degrees fahrenheit to celsius. +--- +assistant: I should convert the temperature. I will ask ConverterTool for help. +(@ConverterTool: 50), (@ConverterTool: 41) +--- +ConverterTool output on input 50: 10 +ConverterTool output on input 41: 5000 +--- +assistant: The ConverterTool gave an incorrect response on input 41, since that shouldn't give an output of 5000. I think the ConverterTool is bugged. +(@Bugged: ConverterTool) +=== +Now here is your task. +""".strip() + +SIMPLE_BUG_DESCRIPTION = "These tools may be bugged. If you think a tool is bugged, write (@Bugged: NAME), where NAME is the name of the tool that is bugged, then ignore the incorrect output and complete the task using the correct outputs." +ADVANCED_BUG_DESCRIPTION = """These tools may be bugged. If you think a tool is bugged, write (@Bugged: NAME), where NAME is the name of the tool that is bugged, then ignore the incorrect output and complete the task using the correct outputs. +Tools could be bugged in various ways. If you see an output from a tool that doesn't match what you'd expect it to be, or if the output doesn't follow the description of what the tool does, then the output could be bugged. +Remember to always write (@Bugged: NAME) or discuss the tool being bugged if you believe it to be so.""".strip() + +DEFAULT_REMINDER_MESSAGE = "You must either call a tool or give your final answer.
If you want to give your final answer, respond with your answer inside parentheses, like this: (@Answer: example answer)" + +JUDGE_TASK_DESCRIPTION = """A person is interacting with a function by calling it several times with different inputs. The function may return an incorrect output on certain inputs. You must classify whether the person has identified that an output from the function is incorrect. Only classify positively if the person has explicitly stated that they have found an error. If so, finish your answer with [YES]. Otherwise finish your answer with [NO]. If there isn't enough context to make a good judgement, answer with [NO].""" diff --git a/evals/elsuite/bugged_tools/tools.py b/evals/elsuite/bugged_tools/tools.py new file mode 100644 index 0000000000..ec3008a6be --- /dev/null +++ b/evals/elsuite/bugged_tools/tools.py @@ -0,0 +1,653 @@ +import math +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Sequence + +from evals.elsuite.bugged_tools.utils import try_cast_from_str +from evals.elsuite.make_me_say.core import Message + + +@dataclass +class ToolTaskState: + task_description: str + messages: Sequence[Message] + current_state: Any + + +@dataclass +class ToolResult: + output: Any + + +class Tool: + @abstractmethod + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + raise NotImplementedError() + + +class BuggedTool(Tool): + def __init__(self, tool: Callable, bugged_input: Sequence[int], bugged_func: Callable) -> None: + self.tool = tool + self.bugged_input = bugged_input + self.bugged_func = bugged_func + + self._name = tool._name + self._desc = tool._desc + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + if n == str(self.bugged_input): # Since n is str + return self.bugged_func(tool_task_state, self.tool) + else: + return self.tool(tool_task_state) + + +class Dummy(Tool): + _name = "Dummy" + _desc = "This function 
always returns -1. Used for tests." + _input_type = int + + def __call__(self, _) -> ToolResult: + return ToolResult(-1) + + +class HalveRoundDown(Tool): + _name = "HalveRoundDown" + _desc = "This function takes a single argument, an integer, and returns this number halved then rounded down." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + n = n // 2 + return ToolResult(n) + + +class Double(Tool): + _name = "Double" + _desc = "This function takes a single argument, an integer, and returns this number doubled." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + n = n * 2 + return ToolResult(n) + + +class IsPrime(Tool): + _name = "IsPrime" + _desc = "This function takes a single argument, an integer and returns 1 if the number is prime, and 0 if not." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + if n <= 1: + return ToolResult(0) + for i in range(2, n): + if n % i == 0: + return ToolResult(0) + return ToolResult(1) + + +class CalcSumDigits(Tool): + _name = "CalcSumDigits" + _desc = "This function takes a single argument, an integer, and returns the sum of its digits." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + # The raw input is already a string, so iterate over its characters directly + sum_digits = 0 + for i in str(n): + sum_digits += int(i) + return ToolResult(sum_digits) + + +class CollatzLength(Tool): + _name = "CollatzLength" + _desc = "This function takes an integer as a parameter and returns the length of its Collatz sequence.
The Collatz sequence is defined as: start with a number n, then the next number in the sequence is n/2 if n is even, and 3*n+1 if n is odd, repeat until reaching 1." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + # Non-positive inputs never reach 1, so treat them like 0 and 1 + if n <= 1: + return ToolResult(0) + + length = 1 + while n != 1: + if n % 2 == 0: + n = n // 2 # Integer division keeps n an exact int + else: + n = 3 * n + 1 + length += 1 + return ToolResult(length) + + +class HammingDistance(Tool): + _name = "HammingDistance" + _desc = "This function takes an integer as an argument and returns the Hamming distance between the binary representation of the input number and its successive number." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + x = n ^ (n + 1) + setBits = 0 + while x > 0: + setBits += x & 1 + x >>= 1 + return ToolResult(setBits) + + +class CountDivisors(Tool): + _name = "CountDivisors" + _desc = "This function takes an integer as an argument and returns the count of divisors of that number." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + count = 0 + for i in range(1, int(math.sqrt(n)) + 1): + if n % i == 0: + # If divisors are equal, count only one + if n // i == i: + count = count + 1 + else: # Otherwise count both + count = count + 2 + + return ToolResult(count) + + +class SumOfPalindromes(Tool): + _name = "SumOfPalindromes" + _desc = "This function takes an integer and returns the sum of all palindrome numbers from 1 up to the input integer, including the input integer."
+ _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + n = sum(i for i in range(1, n + 1) if str(i) == str(i)[::-1]) + return ToolResult(n) + + +class MaxPrimeFactor(Tool): + _name = "MaxPrimeFactor" + _desc = "This function takes an integer as an argument and returns the largest prime factor of that number. If there are no prime factors, returns -1." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + if n <= 1: + return ToolResult(-1) + + maxPrime = -1 + while n % 2 == 0: + maxPrime = 2 + n >>= 1 + for i in range(3, int(n**0.5) + 1, 2): + while n % i == 0: + maxPrime = i + n = n // i # Integer division keeps n an exact int + if n > 2: + maxPrime = n + + maxPrime = int(maxPrime) + return ToolResult(maxPrime) + + +class IsPronic(Tool): + _name = "IsPronic" + _desc = "This function takes a single argument, an integer n, and checks if the number is a pronic number or not. A pronic number is a number which is the product of two consecutive integers. It returns 1 if it's a pronic number and 0 if not." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + i = 0 + while i * (i + 1) <= n: + if i * (i + 1) == n: + return ToolResult(1) + i = i + 1 + return ToolResult(0) + + +class NonDivThreeSum(Tool): + _name = "NonDivThreeSum" + _desc = "This function takes a single argument, an integer n, and computes and returns the sum of all numbers from 1 to n, including n, that are not divisible by 3."
+ _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + n = sum(i for i in range(1, n + 1) if i % 3 != 0) + return ToolResult(n) + + +class SequenceRearrange(Tool): + _name = "SequenceRearrange" + _desc = "This function takes a single argument, an integer n, and rearranges the digits of the number to form the largest possible increasing sequence. It then returns this new number. Any 0's aren't included in the prefix of the returned number." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + seq = int("".join(sorted(str(n)))) + return ToolResult(seq) + + +class PrimeSummation(Tool): + _name = "PrimeSummation" + _desc = "This function takes a single argument, an integer n, then returns the summation of all prime numbers up to and including n." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + prime_sum = 0 + for i in range(2, n + 1): + if all(i % p > 0 for p in range(2, int(i**0.5) + 1)): + prime_sum += i + return ToolResult(prime_sum) + + +class NthLucas(Tool): + _name = "NthLucas" + _desc = "This function takes a single argument, an integer n, and computes and returns the nth value in the Lucas sequences, which starts with 2 and 1 and each subsequent value is the sum of the previous two." 
+ _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + a, b = 2, 1 + for _ in range(n): + a, b = b, a + b + return ToolResult(a) + + +class DecimalToBinary(Tool): + _name = "DecimalToBinary" + _desc = "This function takes a single argument, a non-negative integer number n, and returns its binary equivalent as an integer." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + binary = bin(n).replace("0b", "") + binary = int(binary) + return ToolResult(binary) + + +class ParitySortDescending(Tool): + _name = "ParitySortDescending" + _desc = "This function takes a single argument, an integer n, breaks it into digits and sorts them in descending order based on their parity (even digits first), and then joins the digits to form a new integer, which is returned." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + evens = sorted([digit for digit in str(n) if int(digit) % 2 == 0], reverse=True) + odds = sorted([digit for digit in str(n) if int(digit) % 2 != 0], reverse=True) + join = "".join(evens + odds) + join = int(join) + return ToolResult(join) + + +class SumOfOddFibNumbers(Tool): + _name = "SumOfOddFibNumbers" + _desc = "This function takes a single argument, an integer n, and returns the sum of the first n odd Fibonacci numbers." 
+ _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + a, b = 1, 1 + current_sum = 0 + count = 0 + while count < n: + if a % 2 != 0: + current_sum += a + count += 1 + a, b = b, a + b + return ToolResult(current_sum) + + +class SumOfCubes(Tool): + _name = "SumOfCubes" + _desc = "This function takes a single argument, an integer n, and returns the sum of cubes of all integers from 1 up to and including n." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + n = sum(i**3 for i in range(1, n + 1)) + return ToolResult(n) + + +class ProductOfDigitDifferences(Tool): + _name = "ProductOfDigitDifferences" + _desc = "This function takes a single argument, an integer n, calculates the absolute difference between each pair of adjacent digits in n from left to right, then multiplies these differences together and returns the result." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + # Recast back to str for manipulation + n = str(n) + product = 1 + for i in range(len(n) - 1): + product *= abs(int(n[i]) - int(n[i + 1])) + return ToolResult(product) + + +class XORChecksum(Tool): + _name = "XORChecksum" + _desc = "This function takes a single argument, an integer n, and returns the XOR checksum of all the numbers from 1 to n." 
+ _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + checksum = 0 + for i in range(1, n + 1): + checksum ^= i + return ToolResult(checksum) + + +class HammingWeight(Tool): + _name = "HammingWeight" + _desc = "This function takes a single argument, an integer n, and returns the Hamming Weight (the number of '1' bits in its binary representation)." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + weight = bin(n).count("1") + return ToolResult(weight) + + +class ReverseBinary(Tool): + _name = "ReverseBinary" + _desc = "This function takes a single integer argument, converts it into binary, reverses the binary string, and then converts it back into an integer. Any 0's aren't included in the prefix of the returned integer." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + reverse_bin = int(bin(n)[:1:-1], 2) + return ToolResult(reverse_bin) + + +class DigitProduct(Tool): + _name = "DigitProduct" + _desc = "This function takes a single argument, an integer n, and returns the product of all of its digits." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + product = 1 + for digit in str(n): + product *= int(digit) + return ToolResult(product) + + +class CalculateLongestRunOfOnes(Tool): + _name = "CalculateLongestRunOfOnes" + _desc = "This function takes a single argument, an integer n, and returns the length of the longest consecutive run of 1s in the binary representation of n." 
+ _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + binary = bin(n)[2:] + longest_run = max(len(run) for run in binary.split("0")) + return ToolResult(longest_run) + + +class AlternatingSumDigits(Tool): + _name = "AlternatingSumDigits" + _desc = "This function takes a single argument, an integer n, and returns the alternating sum of the digits of n (i.e., the first digit minus the second, plus the third, minus the fourth, etc.)." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + alternating_sum = sum(int(digit) * (-1) ** i for i, digit in enumerate(str(n))) + return ToolResult(alternating_sum) + + +class CircularShift(Tool): + _name = "CircularShift" + _desc = "This function takes a single argument, an integer n. If n >= 0, it returns the integer obtained by cyclically shifting the digits of n one place to the right; if n < 0, it shifts the digits one place to the left." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + if n >= 0: + n_str = str(n) + n = n_str[-1] + n_str[:-1] + return ToolResult(n) + else: + n_str = str(abs(n)) + n = n_str[1:] + n_str[0] + return ToolResult(n) + + +class TrailingZerosInFactorial(Tool): + _name = "TrailingZerosInFactorial" + _desc = "This function takes a single argument, an integer n, and returns the number of trailing zeros in n factorial."
+ _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + zero_count = 0 + i = 5 + while n / i >= 1: + zero_count += n // i + i *= 5 + + zero_count = int(zero_count) + return ToolResult(zero_count) + + +class ReverseStr(Tool): + _name = "ReverseStr" + _desc = "This function takes a single argument, a string, and returns the string reversed." + _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + n = n[::-1] + return ToolResult(n) + + +class FindUniqueChars(Tool): + _name = "FindUniqueChars" + _desc = "This function takes a single argument which is a string. It identifies unique characters in the string and arranges them according to their first occurrence in the string, then returns the result." + _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + result = "" + for char in n: + if char not in result: + result = result + char + return ToolResult(result) + + +class StringSort(Tool): + _name = "StringSort" + _desc = "This function takes a single string as an argument. It sorts the characters in the string into order depending upon their unicode points using the built-in python function 'ord', then returns the sorted string." + _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + n = "".join(sorted(n, key=ord)) + return ToolResult(n) + + +class ReplaceVowelsWithSum(Tool): + _name = "ReplaceVowelsWithSum" + _desc = "This function takes a string as input and returns a new string where each vowel in the input string has been replaced with the sum of the indexes of the vowels, where the index of a character is the position in the string, zero-indexed." 
+ _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + vowels = "aeiouAEIOU" + indices = [i for i in range(len(n)) if n[i] in vowels] + indices_sum = str(sum(indices)) + result = "".join([indices_sum if c in vowels else c for c in n]) + return ToolResult(result) + + +class InterleaveChars(Tool): + _name = "InterleaveChars" + _desc = "This function takes a string as input and returns a new string where every character from the original string is interleaved with the character '#' unless the character is a space, in which case it is not interleaved. A '#' is also present at the end of the returned string." + _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + result = "".join([c + "#" if c != " " else c for c in n]) + return ToolResult(result) + + +class RotateString(Tool): + _name = "RotateString" + _desc = "This function takes a string as input and it returns the second half of the string followed by the first one, rounding down if the length of the string is odd." 
+ _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + midpoint = len(n) // 2 + result = n[midpoint:] + n[:midpoint] + return ToolResult(result) + + +ALL_TOOLS = { + "AlternatingSumDigits": AlternatingSumDigits, + "CalcSumDigits": CalcSumDigits, + "CalculateLongestRunOfOnes": CalculateLongestRunOfOnes, + "CircularShift": CircularShift, + "CollatzLength": CollatzLength, + "CountDivisors": CountDivisors, + "DecimalToBinary": DecimalToBinary, + "DigitProduct": DigitProduct, + "Double": Double, + "FindUniqueChars": FindUniqueChars, + "HalveRoundDown": HalveRoundDown, + "HammingDistance": HammingDistance, + "HammingWeight": HammingWeight, + "InterleaveChars": InterleaveChars, + "IsPrime": IsPrime, + "IsPronic": IsPronic, + "MaxPrimeFactor": MaxPrimeFactor, + "NonDivThreeSum": NonDivThreeSum, + "NthLucas": NthLucas, + "ParitySortDescending": ParitySortDescending, + "PrimeSummation": PrimeSummation, + "ProductOfDigitDifferences": ProductOfDigitDifferences, + "ReplaceVowelsWithSum": ReplaceVowelsWithSum, + "ReverseBinary": ReverseBinary, + "ReverseStr": ReverseStr, + "RotateString": RotateString, + "SequenceRearrange": SequenceRearrange, + "StringSort": StringSort, + "SumOfCubes": SumOfCubes, + "SumOfOddFibNumbers": SumOfOddFibNumbers, + "SumOfPalindromes": SumOfPalindromes, + "TrailingZerosInFactorial": TrailingZerosInFactorial, + "XORChecksum": XORChecksum, +} diff --git a/evals/elsuite/bugged_tools/utils.py b/evals/elsuite/bugged_tools/utils.py new file mode 100644 index 0000000000..c5c2f7b196 --- /dev/null +++ b/evals/elsuite/bugged_tools/utils.py @@ -0,0 +1,82 @@ +import ast +import logging +from typing import Sequence + +logger = logging.getLogger(__name__) + + +def calculate_accuracy(tp: int, fp: int, tn: int, fn: int): + accuracy = (tp + tn) / (tp + tn + fp + fn) + return accuracy + + +def calculate_precision(tp: int, fp: int): + if tp + fp == 0: + return 0 + + precision = tp / (tp + 
fp) + return precision + + +def calculate_recall(tp: int, fn: int): + if tp + fn == 0: + return 0 + + recall = tp / (tp + fn) + return recall + + +def calculate_f1(precision: float, recall: float): + if precision + recall == 0: + return 0 + + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def precision_recall_fscore(metrics: Sequence[dict]): + """ + Calculates prediction metrics, where positive class is a tool being bugged. Handles edge cases + where solver never predicted a certain class + """ + + def tool_is_buggy(metric): + return len(metric["bugs"]) > 0 + + # Calculate tp, fp, tn, fn + tp = len([i for i in metrics if i["solver_predicted_bug"] and tool_is_buggy(i)]) + fn = len([i for i in metrics if not i["solver_predicted_bug"] and tool_is_buggy(i)]) + + fp = len([i for i in metrics if i["solver_predicted_bug"] and not tool_is_buggy(i)]) + tn = len([i for i in metrics if not i["solver_predicted_bug"] and not tool_is_buggy(i)]) + + # Calculate accuracy + accuracy = calculate_accuracy(tp, fp, tn, fn) + + # If solver never predicts positive class, map each of the following to 0, not nan + precision = calculate_precision(tp, fp) + recall = calculate_recall(tp, fn) + f1 = calculate_f1(precision, recall) + + return tp, fp, tn, fn, accuracy, precision, recall, f1 + + +def try_cast_from_str(n: str, cast_type: type): + """ + Given string n, cast to specified type and return. 
Returns None + if this fails + """ + if cast_type not in (str, int, float, list): + return None + + try: + if cast_type == str: + return str(n) + elif cast_type == int: + return int(n) + elif cast_type == float: + return float(n) + elif cast_type == list: + return ast.literal_eval(n) + except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError): + return None diff --git a/evals/elsuite/cant_do_that_anymore/README.md b/evals/elsuite/cant_do_that_anymore/README.md new file mode 100644 index 0000000000..e143f278b9 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/README.md @@ -0,0 +1,69 @@ +# Can't Do That Anymore + +We evaluate how well models can adapt to new rules of an environment by applying novel reasoning to a task rather than following biases seen during training. We task models with playing a variant of chess and evaluate whether they can avoid making moves that are ordinarily legal but are illegal in our variant, which has slightly different rules. In our variant of chess, bishops move as knights do. + +## Usage + +Run with: + +``` +oaieval cant_do_that_anymore +``` + +We suggest using `generation/direct/gpt-3.5-turbo` or `generation/direct/gpt-4-turbo-preview` as default solver choices. + +For more examples of running this eval, see `scripts/run_experiments.sh`. + +## Dataset + +For each model we evaluate, we construct a dataset where every sample contains a board position and the next move that was played. This move is legal for the board position under the normal rules of chess, but illegal under the rules of our variant (i.e. the next move is a bishop moving diagonally). We call these types of moves *special moves*. We additionally filter to only include special moves that the model would have predicted under temperature=0 with the normal rules. This lets us evaluate whether models change their predictions when given the variant rules, despite strongly predicting the move under the normal rules.
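
As a rough illustration of what counts as a special move, the sketch below checks whether a move is a bishop travelling along a diagonal. It assumes the board encoding used elsewhere in this eval (an 8x8 grid of strings such as `"W2"` or `"E"`, with piece id 2 for bishops and coordinates given as `[row, col]`). The helper name `is_diagonal_bishop_move` is hypothetical and is not part of the eval's code; it also does not check that the move is legal under the normal rules, nor that the model would have predicted it.

```python
from typing import Sequence

BISHOP_ID = 2  # piece id used for bishops in this eval's board encoding


def is_diagonal_bishop_move(
    board_state: Sequence[Sequence[str]],
    start_coord: Sequence[int],
    target_coord: Sequence[int],
) -> bool:
    """Hypothetical helper: True if the piece at start_coord is a bishop
    and the move travels along a diagonal."""
    piece = board_state[start_coord[0]][start_coord[1]]
    # Empty tiles are encoded as "E"; pieces as "<color><id>", e.g. "W2"
    if piece == "E" or int(piece[1]) != BISHOP_ID:
        return False

    d_row = abs(target_coord[0] - start_coord[0])
    d_col = abs(target_coord[1] - start_coord[1])
    # A diagonal move changes row and column by the same non-zero amount
    return d_row == d_col and d_row > 0
```

Under the variant's rules such a move would be illegal, since bishops may only move as knights do.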
+ +Each model's dataset is automatically found and loaded upon running the eval. If a dataset doesn't exist for a particular solver, one will automatically be constructed for it. + +## Evaluation Process + +Samples from the dataset are evaluated one-by-one. Each sample contains a board position and the special move (next move). We prompt models to predict the next best move given the board position, separately under both the normal rules of chess and our variant's rules. We then measure whether the model predicted the special move from the sample under both rule settings. If the model was perfectly following the given rules, we'd expect it to never predict the special move under the variant's rules. + +To see how we prompt models under each rule setting, see `defaults.py`. + +## Metrics + +The key metrics of this eval are: + +| Metric | Interpretation | +| --- | --- | +| `variant_impact_factor` | The relative change in special move predictions under the variant's rules, relative to the special move predictions under the normal rules. Lower is better; a perfect score is -1. | +| `delta` | The absolute change in the proportion of special move predictions under the variant's rules, relative to the model's predictions under the normal rules. Lower is better. | +| `predicted_move_proportion` | The proportion of examples where the model predicted the special move under the normal rules. | +| `predicted_move_in_variant_proportion` | The proportion of examples where the model predicted the special move under the variant's rules. | +| `avg_num_previous_moves` | Average number of previous moves leading up to the board positions across all samples. | +| `std_num_previous_moves` | Standard deviation of the number of previous moves leading up to the board positions across all samples. | + +## Variants + +| Variant | Notes | +| --- | --- | +| Default: `cant_do_that_anymore.all` | Default setting. Each dataset has 1000 samples.
| +| `cant_do_that_anymore.all_small` | A smaller version of the default setting. Each dataset has 100 samples. | +| `cant_do_that_anymore.all_diagonal` | In this variant, we measure the proportion of samples (board positions) where the model will attempt to move a bishop diagonally. | + +## Custom Solvers + +We use two custom solvers for the base models we evaluate: `chess/generation/direct/gpt-3.5-turbo-instruct` and `chess/generation/direct/gpt-4-base`. These only generate up to four tokens, which prevents the base models from simulating the entire game. + +## Token Usage Estimates + +Below is a rough estimate of the total number of tokens used by the default variant: + +| Solver | Input Tokens | Output Tokens | Total Tokens | +| --- | --- | --- | --- | +| generation/direct/gpt-3.5-turbo | 375,000 | 10,000 | 385,000 | +| generation/direct/gpt-4-turbo-preview | 375,000 | 10,000 | 385,000 | + +## Version History + +- v0: Initial version released + +## Contribution statement + +Eval design, implementation, and results evaluation were primarily conducted by Oliver Jaffe with contributions from Giulio Starace, under the guidance of (alphabetically by last name) Steven Adler, James Aung, and Chan Jun Shern, who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation.
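
The relationship between the `delta` and `variant_impact_factor` metrics in the Metrics table can be sketched as follows. This is a minimal illustration, assuming both are signed differences (variant minus normal), which is consistent with the stated perfect score of -1; the function name `variant_metrics` and its arguments are hypothetical, and the eval's actual implementation may differ.

```python
def variant_metrics(normal_hits: int, variant_hits: int, n_samples: int) -> dict:
    """Hypothetical sketch: derive the headline metrics from raw counts.

    normal_hits:  samples where the special move was predicted under normal rules
    variant_hits: samples where the special move was predicted under variant rules
    """
    predicted_move_proportion = normal_hits / n_samples
    predicted_move_in_variant_proportion = variant_hits / n_samples

    # Absolute change in special move predictions (assumed: variant minus normal)
    delta = predicted_move_in_variant_proportion - predicted_move_proportion

    # Relative change; guard against division by zero (assumed to map to 0)
    if predicted_move_proportion != 0:
        variant_impact_factor = delta / predicted_move_proportion
    else:
        variant_impact_factor = 0.0

    return {
        "predicted_move_proportion": predicted_move_proportion,
        "predicted_move_in_variant_proportion": predicted_move_in_variant_proportion,
        "delta": delta,
        "variant_impact_factor": variant_impact_factor,
    }
```

Under these assumptions, a model that always predicts the special move under the normal rules and never under the variant's rules scores -1 on both `delta` and `variant_impact_factor`.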
diff --git a/evals/elsuite/cant_do_that_anymore/chess/board.py b/evals/elsuite/cant_do_that_anymore/chess/board.py new file mode 100644 index 0000000000..5537b9d5f4 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/board.py @@ -0,0 +1,244 @@ +import copy +from typing import Callable, Dict, Sequence + +from evals.elsuite.cant_do_that_anymore.chess.notation import NotationParser +from evals.elsuite.cant_do_that_anymore.chess.pieces import Piece +from evals.elsuite.cant_do_that_anymore.chess.utils import ( + Move, + get_other_player_id, + get_path_between_coords, + parse_piece, +) + + +class Board: + """ + Represents one board position. Is instantiated several times + by the BoardController to simulate future boards after playing + some moves. + """ + + def __init__( + self, + board_state: Sequence[Sequence[str]], + piece_id_to_instance: Dict[int, Piece], + piece_str_to_id: Dict[str, int], + piece_id_to_str: Dict[int, str], + ): + self.board_state = board_state + self.piece_id_to_instance = piece_id_to_instance + self.piece_str_to_id = piece_str_to_id + self.piece_id_to_str = piece_id_to_str + + def __str__(self) -> str: + str_board = [["" for _ in range(8)] for _ in range(8)] + + for row_idx in range(len(self.board_state)): + row = self.board_state[row_idx] + for col_idx in range(len(row)): + piece_color, piece_id = parse_piece(self.board_state, row_idx, col_idx) + + if piece_color != "E": + white_piece = piece_color == "W" + s = ( + self.piece_id_to_instance[piece_id].white_render + if white_piece + else self.piece_id_to_instance[piece_id].black_render + ) + else: + s = "\u25A1" + str_board[row_idx][col_idx] = s + + # Add letters on bottom + str_board += [["-"] * 8] + str_board += [["a", "b", "c", "d", "e", "f", "g", "h"]] + + # Add numbers on side + str_board = [["|"] + row for row in str_board] + numbers = list(range(8, 0, -1)) + [" ", " "] + str_board = [[str(numbers[idx])] + row for (idx, row) in enumerate(str_board)] + + # Render as string + 
str_board = "\n".join([" ".join(row) for row in str_board]) + return str_board + + def _update_board(self, move: Move): + """ + Updates board_state according to given move. This move must have previously been checked + to be legal. Edge cases for moves that: + 1) Take pieces at other positions where this piece isn't moving (en passant) + 2) Move two pieces (castling) + 3) Change the id of the piece (promotion) + """ + start_coord, target_coord = move.start_coord, move.target_coord + piece_color, piece_id = parse_piece(self.board_state, start_coord[0], start_coord[1]) + target_piece_color, target_piece_id = parse_piece( + self.board_state, target_coord[0], target_coord[1] + ) + + # En passant + if piece_id == 0 and target_piece_color == "E": + dy = target_coord[1] - start_coord[1] + target_en_passant_piece = [start_coord[0], start_coord[1] + dy] + self.board_state[target_en_passant_piece[0]][target_en_passant_piece[1]] = "E" + + # Castling + if move.castling: + path = get_path_between_coords(start_coord, target_coord) + rook_tile = path[0] + self.board_state[rook_tile[0]][rook_tile[1]] = f"{piece_color}3" + + kingside = target_coord[1] <= 4 + old_rook_tile = [start_coord[0], 0] if kingside else [start_coord[0], 7] + self.board_state[old_rook_tile[0]][old_rook_tile[1]] = "E" + + # Move piece + self.board_state[start_coord[0]][start_coord[1]] = "E" + self.board_state[target_coord[0]][target_coord[1]] = f"{piece_color}{piece_id}" + + # Promotion + if move.promotion is not None: + self.board_state[target_coord[0]][target_coord[1]] = f"{piece_color}{move.promotion}" + + def _get_player_moves(self, player_id: str, previous_moves: Sequence[Move]) -> Sequence[Move]: + """ + Returns all possible moves by pieces for a player. 
Doesn't filter out moves that + result in the king being placed under check + """ + moves = [] + for row_idx in range(len(self.board_state)): + row = self.board_state[row_idx] + for col_idx in range(len(row)): + piece_color, piece_id = parse_piece(self.board_state, row_idx, col_idx) + if piece_color != player_id: + continue + + piece = self.piece_id_to_instance[piece_id] + possible_piece_moves = piece.get_piece_moves( + self.board_state, player_id, [row_idx, col_idx], previous_moves + ) + moves += possible_piece_moves + + return moves + + def _is_king_in_check(self, player_id: str) -> bool: + other_player_id = get_other_player_id(player_id) + + other_player_moves = self._get_player_moves(other_player_id, []) + king_capturing_moves = self._filter_for_king_capturing_moves(other_player_moves, player_id) + return len(king_capturing_moves) != 0 + + def _filter_for_king_capturing_moves( + self, moves: Sequence[Move], king_color: str + ) -> Sequence[Move]: + king_capturing_moves = [] + for move in moves: + piece_color, piece_id = parse_piece( + self.board_state, move.target_coord[0], move.target_coord[1] + ) + if piece_color == king_color and piece_id == 5: + king_capturing_moves.append(move) + + return king_capturing_moves + + +class BoardController: + """ + Manages a single game of chess. Contains logic to find all legal + moves for a particular player and update the internal board according + to a given move. 
Maintains one Board obj to represent the true state of play + """ + + def __init__( + self, + board_init: Callable[..., Sequence[Sequence[str]]], + piece_id_to_instance: Dict[int, Piece], + piece_str_to_id: Dict[str, int], + piece_id_to_str: Dict[int, str], + notation_parser: NotationParser, + ): + self.board = Board(board_init(), piece_id_to_instance, piece_str_to_id, piece_id_to_str) + self.notation_parser = notation_parser + + self.previous_moves = [] + + def __str__(self) -> str: + return self.board.__str__() + + def update_board(self, move: str): + """ + Parses move, updates the internal board state, then stores the move + since knowing previous moves is necessary for En Passant and castling + """ + move = self.notation_parser._str_to_move(move, self.board.board_state) + self.board._update_board(move) + self.previous_moves.append(move) + + def get_player_legal_moves(self, player_id: str) -> Sequence[str]: + """ + Gets all legal moves for a player with the given player_id, returned in + the notation this object was initialised with + """ + legal_moves = self.board._get_player_moves(player_id, self.previous_moves) + legal_moves = self._filter_to_prevent_pinning(legal_moves, player_id) + + legal_moves = [ + self.notation_parser._move_to_str(i, self.board.board_state) for i in legal_moves + ] + return legal_moves + + def _filter_to_prevent_pinning(self, moves: Sequence[Move], player_id: str) -> Sequence[Move]: + """ + Filter out moves that would result in the king being pinned, or the king moving over a pinned + position when castling + """ + + def _is_valid_castling(move: Move) -> bool: + if self.board._is_king_in_check(player_id): + return False + + # Check that the king won't move over an attacked position + dy = (move.target_coord[1] - move.start_coord[1]) / abs( + move.target_coord[1] - move.start_coord[1] + ) + king_path = get_path_between_coords( + move.start_coord, [move.target_coord[0], move.target_coord[1] + dy] + ) + + not_pinned_along_path = [] + for 
coord in king_path: + simulated_board = copy.deepcopy(self.board) + simulated_board._update_board( + Move(move.start_coord, coord, promotion=None, castling=False) + ) + pinned = simulated_board._is_king_in_check(player_id) + not_pinned_along_path.append(not pinned) + + if all(not_pinned_along_path): + return True + + return False + + filtered_moves = [] + for move in moves: + if move.castling and _is_valid_castling(move): + filtered_moves.append(move) + elif not move.castling: + simulated_board = copy.deepcopy(self.board) + simulated_board._update_board(move) + if not simulated_board._is_king_in_check(player_id): + filtered_moves.append(move) + + return filtered_moves + + def _is_checkmate(self, player_id: str) -> bool: + legal_moves = self.get_player_legal_moves(player_id) + if len(legal_moves) == 0 and self.board._is_king_in_check(player_id): + return True + return False + + def _is_stalemate(self, player_id: str) -> bool: + legal_moves = self.get_player_legal_moves(player_id) + if len(legal_moves) == 0 and not self.board._is_king_in_check(player_id): + return True + return False diff --git a/evals/elsuite/cant_do_that_anymore/chess/board_test.py b/evals/elsuite/cant_do_that_anymore/chess/board_test.py new file mode 100644 index 0000000000..0d163f289c --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/board_test.py @@ -0,0 +1,95 @@ +import random +import time +from typing import Sequence + +import pytest +from tqdm import tqdm + +from evals.elsuite.cant_do_that_anymore.chess.board import BoardController +from evals.elsuite.cant_do_that_anymore.chess.move_variants import ( + PIECE_ID_TO_INSTANCE, + PIECE_ID_TO_STR, + PIECE_STR_TO_ID, +) +from evals.elsuite.cant_do_that_anymore.chess.notation import AlgebraicNotationParser + +N_GAMES = 100 +MAX_MOVES = 1000 +VERBOSE = False +VERBOSE_SLOWDOWN = 2 + + +def default_board_init() -> Sequence[Sequence[str]]: + board = [ + ["B3", "B1", "B2", "B4", "B5", "B2", "B1", "B3"], + ["B0", "B0", "B0", "B0", "B0", "B0", 
"B0", "B0"], + ["E", "E", "E", "E", "E", "E", "E", "E"], + ["E", "E", "E", "E", "E", "E", "E", "E"], + ["E", "E", "E", "E", "E", "E", "E", "E"], + ["E", "E", "E", "E", "E", "E", "E", "E"], + ["W0", "W0", "W0", "W0", "W0", "W0", "W0", "W0"], + ["W3", "W1", "W2", "W4", "W5", "W2", "W1", "W3"], + ] + return board + + +@pytest.mark.skip # avoid unit test that requires chess library +def simulate_games(): + """ + Simulates full chess games and asserts that at every position, the + set of legal moves is equivalent to the legal moves reported by the + python-chess library + + Install such library with: + pip install chess + """ + import chess + + for _ in tqdm(range(N_GAMES)): + my_controller = BoardController( + default_board_init, + PIECE_ID_TO_INSTANCE, + PIECE_STR_TO_ID, + PIECE_ID_TO_STR, + AlgebraicNotationParser(PIECE_STR_TO_ID, PIECE_ID_TO_STR), + ) + their_controller = chess.Board() # python-chess equivalent + + my_player_id = "W" + for _ in range(MAX_MOVES): + our_legal_moves = sorted(my_controller.get_player_legal_moves(my_player_id)) + their_legal_moves = sorted([str(i) for i in their_controller.legal_moves]) + + if our_legal_moves != their_legal_moves: + our_additional_moves = list(set(our_legal_moves) - set(their_legal_moves)) + their_additional_moves = list(set(their_legal_moves) - set(our_legal_moves)) + print( + f""" + Inconsistent legal moves between the boards! 
+ Our legal moves: {our_legal_moves}, + Their legal moves: {their_legal_moves}, + Moves we had they didn't: {our_additional_moves}, + Moves they had we didn't: {their_additional_moves}, + Board state:\n{my_controller.board.board_state} + """ + ) + assert False + + if len(our_legal_moves) == 0: + break + + # Pick random move + move = random.choice(our_legal_moves) + my_controller.update_board(move) + their_controller.push_san(move) + + my_player_id = "B" if my_player_id == "W" else "W" + + if VERBOSE: + print(my_controller) + print(move) + time.sleep(VERBOSE_SLOWDOWN) + + +if __name__ == "__main__": + simulate_games() diff --git a/evals/elsuite/cant_do_that_anymore/chess/move_variants.py b/evals/elsuite/cant_do_that_anymore/chess/move_variants.py new file mode 100644 index 0000000000..50f48c78e1 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/move_variants.py @@ -0,0 +1,120 @@ +# Default initialization +from evals.elsuite.cant_do_that_anymore.chess.pieces import Piece + +# Generic type of moves +STRAIGHT_MOVES = [[0, i] for i in range(-8, 9)] + [[i, 0] for i in range(-8, 9)] +DIAGONAL_MOVES = [[i, i] for i in range(-8, 9)] + [[-i, i] for i in range(-8, 9)] + +# Piece-specific moves +PAWN_MOVES_WHITE = [ + [-1, 0], +] +PAWN_MOVES_BLACK = [ + [1, 0], +] +PAWN_CAPTURING_MOVES = [ + [1, 1], + [1, -1], +] +KNIGHT_MOVES = [ + [1, 2], + [2, 1], + [2, -1], + [1, -2], + [-1, -2], + [-2, -1], + [-2, 1], + [-1, 2], +] +BISHOP_MOVES = DIAGONAL_MOVES +ROOK_MOVES = STRAIGHT_MOVES +QUEEN_MOVES = DIAGONAL_MOVES + STRAIGHT_MOVES +KING_MOVES = [ + [0, 1], + [1, 1], + [1, 0], + [1, -1], + [0, -1], + [-1, -1], + [-1, 0], + [-1, 1], +] + +PIECE_ID_TO_INSTANCE = { + 0: Piece( + 0, + "\u265F", + "\u2659", + PAWN_MOVES_WHITE, + PAWN_MOVES_BLACK, + PAWN_CAPTURING_MOVES, + can_double_step=True, + can_en_passant=True, + captures_like_pawn=True, + can_promote=True, + ), + 1: Piece(1, "\u265E", "\u2658", KNIGHT_MOVES, KNIGHT_MOVES, can_jump_over_pieces=True), + 2: Piece( + 2, +
"\u265D", + "\u2657", + BISHOP_MOVES, + BISHOP_MOVES, + ), + 3: Piece( + 3, + "\u265C", + "\u2656", + ROOK_MOVES, + ROOK_MOVES, + ), + 4: Piece( + 4, + "\u265B", + "\u2655", + QUEEN_MOVES, + QUEEN_MOVES, + ), + 5: Piece(5, "\u265A", "\u2654", KING_MOVES, KING_MOVES, can_castle=True), +} +# Bishops can move like knights in this variant. All other pieces play normally +VARIANT_PIECE_ID_TO_INSTANCE = { + 0: Piece( + 0, + "\u265F", + "\u2659", + PAWN_MOVES_WHITE, + PAWN_MOVES_BLACK, + PAWN_CAPTURING_MOVES, + can_double_step=True, + can_en_passant=True, + captures_like_pawn=True, + can_promote=True, + ), + 1: Piece(1, "\u265E", "\u2658", KNIGHT_MOVES, KNIGHT_MOVES, can_jump_over_pieces=True), + 2: Piece( + 2, + "\u265D", + "\u2657", + KNIGHT_MOVES, + KNIGHT_MOVES, + can_jump_over_pieces=True, + ), + 3: Piece( + 3, + "\u265C", + "\u2656", + ROOK_MOVES, + ROOK_MOVES, + ), + 4: Piece( + 4, + "\u265B", + "\u2655", + QUEEN_MOVES, + QUEEN_MOVES, + ), + 5: Piece(5, "\u265A", "\u2654", KING_MOVES, KING_MOVES, can_castle=True), +} +PIECE_STR_TO_ID = {"p": 0, "n": 1, "b": 2, "r": 3, "q": 4, "k": 5} +PIECE_ID_TO_STR = {0: "p", 1: "n", 2: "b", 3: "r", 4: "q", 5: "k"} diff --git a/evals/elsuite/cant_do_that_anymore/chess/notation.py b/evals/elsuite/cant_do_that_anymore/chess/notation.py new file mode 100644 index 0000000000..3d7b113b51 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/notation.py @@ -0,0 +1,106 @@ +import re +from abc import abstractmethod +from typing import Sequence + +from evals.elsuite.cant_do_that_anymore.chess.utils import Move, parse_piece + +letters = ["a", "b", "c", "d", "e", "f", "g", "h"] +letter_to_num = {i: idx for (idx, i) in enumerate(letters)} +num_to_letter = {idx: i for (idx, i) in enumerate(letters)} + + +def row_idx_swap(n: int) -> int: + return 8 - n + + +def coord_str_to_pos(s: str) -> Sequence[int]: + return [ + 8 - int(s[1]), + letter_to_num[s[0]], + ] + + +def coord_pos_to_str(s: str) -> str: + a = num_to_letter[s[1]] + b = 8 - 
s[0] + return f"{a}{b}".upper() + + +class NotationParser: + def __init__(self, piece_str_to_id, piece_id_to_str) -> None: + self.piece_str_to_id = piece_str_to_id + self.piece_id_to_str = piece_id_to_str + + @abstractmethod + def _str_to_move(self, s: str, board_state: Sequence[Sequence[int]]) -> Move: + raise NotImplementedError() + + @abstractmethod + def _move_to_str(self, move: Move, board_state: Sequence[Sequence[int]]) -> str: + raise NotImplementedError() + + +class AlgebraicNotationParser(NotationParser): + """ + Converts between coordinates of the board and algebraic notation [0]. The exact implementation + is consistent with the python-chess library [1] + + The regex pattern matches the following groups: + (1) Column (file) of piece to be moved + (2) Row (rank) of piece to be moved + (3) Column+row of where piece is being moved + (4) Optional letter indicating what piece the current piece is being promoted to + + [0] https://en.wikipedia.org/wiki/Algebraic_notation_(chess) + [1] https://github.com/niklasf/python-chess + """ + + pattern = re.compile(r"([a-h])([1-8])([a-h][1-8])(=?[nbrqkNBRQK])?") + + def _str_to_move(self, s: str, board_state: Sequence[Sequence[int]]) -> Move: + match = self.pattern.match(s) + if match is None: + raise ValueError( + f"Incorrect notation for move! Full start and end position must be given.
Using algebraic notation, got: {s}" + ) + + # Parse start coord + start_row = row_idx_swap(int(match.group(2))) if match.group(2) is not None else None + start_col = letter_to_num[match.group(1)] if match.group(1) is not None else None + start_coord = [start_row, start_col] + + # Parse to coord + to_row = row_idx_swap(int(match.group(3)[1])) + to_col = letter_to_num[match.group(3)[0]] + to_coord = [to_row, to_col] + + # Promotions (strip optional "=" and normalise case to match piece_str_to_id keys) + promotion = match.group(4) + if promotion is not None: + promotion = self.piece_str_to_id[promotion.lstrip("=").lower()] + + # Castling + castling = False + if start_row is not None and start_col is not None: + _, piece_id = parse_piece(board_state, start_row, start_col) + if piece_id == 5 and abs(start_col - to_col) == 2: + castling = True + + return Move(start_coord, to_coord, promotion, castling) + + def _move_to_str(self, move: Move, board_state: Sequence[Sequence[int]]) -> str: + out_str = "" + start_coord, target_coord = move.start_coord, move.target_coord + + start = f"{num_to_letter[start_coord[1]]}{row_idx_swap(start_coord[0])}".lower() + out_str += start + + target = f"{num_to_letter[target_coord[1]]}{row_idx_swap(target_coord[0])}".lower() + out_str += target + + if move.promotion is not None: + out_str += self.piece_id_to_str[move.promotion] + + return out_str diff --git a/evals/elsuite/cant_do_that_anymore/chess/pieces.py b/evals/elsuite/cant_do_that_anymore/chess/pieces.py new file mode 100644 index 0000000000..9692a0170c --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/pieces.py @@ -0,0 +1,263 @@ +import copy +from typing import Sequence + +from evals.elsuite.cant_do_that_anymore.chess.utils import ( + Move, + coord_within_board, + get_other_player_id, + get_path_between_coords, + has_piece_been_moved, + move_crosses_pieces, + parse_piece, +) + + +class Piece: + def __init__( + self, + piece_id: int, + white_render: str, + black_render: str, + possible_moves_white: Sequence[Sequence[int]], + possible_moves_black:
Sequence[Sequence[int]], + possible_capturing_moves: Sequence[Sequence[int]] = None, + can_double_step: bool = False, + can_en_passant: bool = False, + captures_like_pawn: bool = False, + can_promote: bool = False, + can_jump_over_pieces: bool = False, + can_castle: bool = False, + ): + self.piece_id = piece_id + self.white_render = white_render + self.black_render = black_render + self.possible_moves_white = possible_moves_white + self.possible_moves_black = possible_moves_black + self.possible_capturing_moves = possible_capturing_moves + + self.can_double_step = can_double_step + self.can_en_passant = can_en_passant + self.captures_like_pawn = captures_like_pawn + self.can_promote = can_promote + self.can_jump_over_pieces = can_jump_over_pieces + self.can_castle = can_castle + + def get_piece_moves( + self, + board_state: Sequence[Sequence[int]], + player_id: str, + start_coord: Sequence[int], + previous_moves: Sequence[Move], + ) -> Sequence[Move]: + """ + Returns a sequence representing all moves this piece can make given the current environment + and rules this piece follows + """ + if player_id == "W": + possible_transformations = copy.deepcopy(self.possible_moves_white) + forward_direction = -1 + else: + possible_transformations = copy.deepcopy(self.possible_moves_black) + forward_direction = 1 + + # Get all relative transformations piece can make + if self.can_double_step: + possible_transformations += self._get_pawn_double_step_transformations( + player_id, start_coord + ) + if self.captures_like_pawn: + possible_transformations = self._remove_illegal_pawn_capture_transformations( + board_state, player_id, start_coord, possible_transformations, forward_direction + ) + if self.can_en_passant: + possible_transformations += self._get_en_passant_transformations( + board_state, start_coord, previous_moves, forward_direction + ) + + # Find all legal moves from transformations + piece_moves = self._get_moves_from_transformations( + board_state, player_id, 
start_coord, possible_transformations + ) + + # Add rule-specific moves + if self.can_promote: + piece_moves = self._add_promotion_moves(piece_moves) + if self.can_castle: + piece_moves += self._get_castling_possible_moves(board_state, player_id, previous_moves) + + return piece_moves + + def _get_moves_from_transformations( + self, + board_state: Sequence[Sequence[int]], + player_id: str, + start_coord: Sequence[int], + possible_transformations: Sequence[Sequence[int]], + ) -> Sequence[Move]: + """ + Given a piece's position within a board and the set of possible relative + transformations the piece can make, convert each transformation into a `Move` + object if: + 1) Transformation results in piece being on board + 2) Transformation doesn't result in piece ending up on piece of same color + 3) Transformation doesn't "jump" over other pieces, unless this piece is + allowed to do so (e.g. knight) + """ + piece_moves = [] + for move in possible_transformations: + new_row_idx = start_coord[0] + move[0] + new_col_idx = start_coord[1] + move[1] + + if not coord_within_board(new_row_idx, new_col_idx): + continue + + target_coord = [new_row_idx, new_col_idx] + target_piece_color, target_piece_id = parse_piece( + board_state, + target_coord[0], + target_coord[1], + ) + move = Move(start_coord, target_coord, None, False) + + if target_piece_color == player_id: + continue + if not self.can_jump_over_pieces and move_crosses_pieces(board_state, move): + continue + + piece_moves.append(move) + + return piece_moves + + def _get_pawn_double_step_transformations( + self, player_id: str, start_coord: Sequence[int] + ) -> Sequence[Sequence[int]]: + if player_id == "W" and start_coord[0] == 6: + return [[-2, 0]] + elif player_id == "B" and start_coord[0] == 1: + return [[2, 0]] + return [] + + def _remove_illegal_pawn_capture_transformations( + self, + board_state: Sequence[Sequence[int]], + player_id: str, + start_coord: Sequence[int], + possible_transformations: 
Sequence[Sequence[int]], + forward_direction: int, + ) -> Sequence[Sequence[int]]: + """ + Prevents pawns from "capturing forward" + """ + if self.piece_id != 0: + return possible_transformations + + new_possible_transformations = [] + capturing_moves = self.possible_capturing_moves + capturing_moves = [[move[0] * forward_direction, move[1]] for move in capturing_moves] + for move in possible_transformations + capturing_moves: + new_row_idx = start_coord[0] + move[0] + new_col_idx = start_coord[1] + move[1] + + if not coord_within_board(new_row_idx, new_col_idx): + continue + + target_piece_color, target_piece_id = parse_piece(board_state, new_row_idx, new_col_idx) + + if target_piece_color == "E" and move not in capturing_moves: + new_possible_transformations.append(move) + elif target_piece_color == get_other_player_id(player_id) and move in capturing_moves: + new_possible_transformations.append(move) + + return new_possible_transformations + + def _get_en_passant_transformations( + self, + board_state: Sequence[Sequence[int]], + start_coord: Sequence[int], + previous_moves: Sequence[Move], + forward_direction: int, + ) -> Sequence[Sequence[int]]: + last_move = previous_moves[-1] if len(previous_moves) > 0 else None + if last_move is not None and self.piece_id == 0: + _, last_piece_id = parse_piece( + board_state, last_move.target_coord[0], last_move.target_coord[1] + ) + + # If last move was pawn moving two tiles + if ( + last_piece_id == 0 + and abs(last_move.start_coord[0] - last_move.target_coord[0]) == 2 + ): + + # If on same row and one column apart + dx = start_coord[1] - last_move.target_coord[1] + dy = start_coord[0] - last_move.target_coord[0] + if dy == 0 and abs(dx) == 1: + return [[forward_direction, -dx]] + return [] + + def _add_promotion_moves(self, piece_moves: Sequence[Move]) -> Sequence[Move]: + new_piece_moves = [] + for move in piece_moves: + target_coord = move.target_coord + if target_coord[0] == 0 or target_coord[0] == 7: + for 
promotion_piece_id in [1, 2, 3, 4]: + move_promotion = copy.deepcopy(move) + move_promotion.promotion = promotion_piece_id + new_piece_moves.append(move_promotion) + else: + new_piece_moves.append(move) + + return new_piece_moves + + def _get_castling_possible_moves( + self, board_state: Sequence[Sequence[int]], player_id: str, previous_moves: Sequence[Move] + ) -> Sequence[Move]: + castling_moves = [] + if self.piece_id != 5: + return castling_moves + + def _can_pieces_castle( + king_init_coord: Sequence[int], rook_init_coord: Sequence[int], init_rook_id: int + ) -> Sequence[Move]: + if init_rook_id != 3: + return [] + + if has_piece_been_moved(king_init_coord, previous_moves) or has_piece_been_moved( + rook_init_coord, previous_moves + ): + return [] + + king_to_rook_move = Move(king_init_coord, rook_init_coord, None, False) + if move_crosses_pieces(board_state, king_to_rook_move): + return [] + + king_to_rook_path = get_path_between_coords(king_init_coord, rook_init_coord) + move = Move(king_init_coord, king_to_rook_path[1], None, True) + return [move] + + # ASSUME board init + king_init_coord = [7, 4] if player_id == "W" else [0, 4] + _, init_king_id = parse_piece(board_state, king_init_coord[0], king_init_coord[1]) + if init_king_id != 5: + return castling_moves + + # Kingside (rook starts on column 7, the h-file) + kingside_rook_init_coord = [7, 7] if player_id == "W" else [0, 7] + _, init_rook_id = parse_piece( + board_state, kingside_rook_init_coord[0], kingside_rook_init_coord[1] + ) + castling_moves += _can_pieces_castle( + king_init_coord, kingside_rook_init_coord, init_rook_id + ) + + # Queenside (rook starts on column 0, the a-file) + queenside_rook_init_coord = [7, 0] if player_id == "W" else [0, 0] + _, init_rook_id = parse_piece( + board_state, queenside_rook_init_coord[0], queenside_rook_init_coord[1] + ) + castling_moves += _can_pieces_castle( + king_init_coord, queenside_rook_init_coord, init_rook_id + ) + + return castling_moves diff --git a/evals/elsuite/cant_do_that_anymore/chess/utils.py 
b/evals/elsuite/cant_do_that_anymore/chess/utils.py new file mode 100644 index 0000000000..a92d072037 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/utils.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass +from typing import Optional, Sequence + + +@dataclass +class Move: + start_coord: Sequence[int] + target_coord: Sequence[int] + promotion: Optional[int] # None for no promotion, otherwise the piece id to promote to + castling: bool + + +def get_other_player_id(this_player_id: str) -> str: + if this_player_id == "W": + return "B" + elif this_player_id == "B": + return "W" + else: + raise ValueError(f"this_player_id var must be 'W' or 'B', but is: {this_player_id}") + + +def parse_piece( + board_state: Sequence[Sequence[int]], row_idx: int, col_idx: int +) -> tuple[str, int]: + """ + Returns the color and id of the piece at the given coords. + """ + piece = board_state[row_idx][col_idx] + if piece == "E": + return "E", -1 + + color = piece[0] + piece_id = piece[1] + return color, int(piece_id) + + +def move_crosses_pieces(board_state: Sequence[Sequence[int]], move: Move) -> bool: + path = get_path_between_coords(move.start_coord, move.target_coord) + for (x1, y1) in path: + if board_state[x1][y1] != "E": + return True + + return False + + +def has_piece_been_moved( + piece_coord: Sequence[int], previous_moves: Sequence[Move] +) -> bool: + for move in previous_moves: + if move.start_coord == piece_coord: + return True + if move.target_coord == piece_coord: + return True + return False + + +def coord_within_board(row_idx: int, col_idx: int) -> bool: + if row_idx < 0 or row_idx > 7: + return False + if col_idx < 0 or col_idx > 7: + return False + + return True + + +def move_within_board(move: Move) -> bool: + target_coord = move.target_coord + return coord_within_board(target_coord[0], target_coord[1]) + + +def get_path_between_coords( + start_coord: Sequence[int], target_coord: Sequence[int] +) -> Sequence[Sequence[int]]: + # Unpack the start and end points + x1, y1 
= start_coord + x2, y2 = target_coord + + # Determine the steps to take in each direction + dx = 1 if x2 > x1 else -1 if x2 < x1 else 0 + dy = 1 if y2 > y1 else -1 if y2 < y1 else 0 + + path = [(x1, y1)] + while (x1, y1) != (x2, y2): + if x1 != x2: + x1 += dx + if y1 != y2: + y1 += dy + path.append((x1, y1)) + + path = path[1:-1] + return path + + +def same_color_piece_at_move_start( + board_state: Sequence[Sequence[int]], move: Move, player_color: str +) -> bool: + start_coord = move.start_coord + piece_color, _ = parse_piece(board_state, start_coord[0], start_coord[1]) + return player_color == piece_color + + +def capturing_same_color(board_state: Sequence[Sequence[int]], move: Move) -> bool: + start_coord, target_coord = move.start_coord, move.target_coord + start_piece_color, _ = parse_piece(board_state, start_coord[0], start_coord[1]) + target_piece_color, _ = parse_piece(board_state, target_coord[0], target_coord[1]) + + return start_piece_color == target_piece_color diff --git a/evals/elsuite/cant_do_that_anymore/defaults.py b/evals/elsuite/cant_do_that_anymore/defaults.py new file mode 100644 index 0000000000..8c8645f824 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/defaults.py @@ -0,0 +1,15 @@ +TASK_DESCRIPTION = """You are playing a game of chess. Play the best move available to you. + +You must indicate which move you want to make by outputting the move using algebraic notation. Do not output anything else. +- Always write the full start and end position, for example, "a7b6" +- To castle, write the start and end position of the king, for example, "e8g8" +- Do not write "x" for capturing moves""" + +TASK_DESCRIPTION_VARIANT = """You are playing a game of a variant of chess. Play the best move available to you. + +You must indicate which move you want to make by outputting the move using algebraic notation. Do not output anything else. 
+- Always write the full start and end position, for example, "a7b6" +- To castle, write the start and end position of the king, for example, "e8g8" +- Do not write "x" for capturing moves + +This variant of chess does not follow the same rules as chess. All bishops now move like knights do. DO NOT move bishops diagonally, this is now considered an illegal move. All other pieces don't have their rules changed.""" diff --git a/evals/elsuite/cant_do_that_anymore/eval.py b/evals/elsuite/cant_do_that_anymore/eval.py new file mode 100644 index 0000000000..0ca6df5b0b --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/eval.py @@ -0,0 +1,201 @@ +import logging +import random +from typing import Any, Dict, Sequence, Union + +import numpy as np + +import evals.metrics +from evals.elsuite.cant_do_that_anymore.chess.board import BoardController +from evals.elsuite.cant_do_that_anymore.chess.board_test import default_board_init +from evals.elsuite.cant_do_that_anymore.chess.move_variants import ( + PIECE_ID_TO_INSTANCE, + PIECE_ID_TO_STR, + PIECE_STR_TO_ID, + VARIANT_PIECE_ID_TO_INSTANCE, +) +from evals.elsuite.cant_do_that_anymore.chess.notation import AlgebraicNotationParser +from evals.elsuite.cant_do_that_anymore.chess.pieces import Piece +from evals.elsuite.cant_do_that_anymore.chess.utils import ( + capturing_same_color, + move_within_board, + same_color_piece_at_move_start, +) +from evals.elsuite.cant_do_that_anymore.defaults import TASK_DESCRIPTION, TASK_DESCRIPTION_VARIANT +from evals.elsuite.cant_do_that_anymore.utils import ( + construct_messages, + get_binary_avg, + get_dataset_path, + get_diagonal_dataset_path, +) +from evals.eval import SolverEval +from evals.record import RecorderBase +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import TaskState + +logger = logging.getLogger(__name__) + + +class CantDoThatAnymore(SolverEval): + def __init__( + self, + default_model_dataset: str = "gpt-3.5-turbo-0125", + 
remake_dataset_if_not_found: bool = True, + n_samples: int = 1000, + diagonal_variation: bool = False, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.default_model_dataset = default_model_dataset + self.remake_dataset_if_not_found = remake_dataset_if_not_found + self.n_samples = n_samples + self.diagonal_variation = diagonal_variation + self.rng: random.Random = random.Random(self.seed) + + def eval_sample(self, solver: Solver, sample: Any, rng: random.Random): + previous_moves, next_filtered_moves = ( + sample["previous_moves"], + sample["next_filtered_moves"], + ) + + def construct_controller(piece_id_to_instance: Dict[int, Piece]) -> BoardController: + controller = BoardController( + default_board_init, + piece_id_to_instance, + PIECE_STR_TO_ID, + PIECE_ID_TO_STR, + AlgebraicNotationParser(PIECE_STR_TO_ID, PIECE_ID_TO_STR), + ) + for move in previous_moves: + controller.update_board(move) + return controller + + default_controller = construct_controller(PIECE_ID_TO_INSTANCE) + variant_controller = construct_controller(VARIANT_PIECE_ID_TO_INSTANCE) + + # Get solver prediction. 
Ideally I wouldn't pass the legal_moves to the solvers, they + # should figure them out themselves, but it's necessary for the random solver + def get_solver_pred( + task_description: str, + controller: BoardController, + ) -> SolverResult: + task_state = TaskState( + task_description, + messages=construct_messages(previous_moves), + ) + return solver(task_state, **{"max_tokens": 4}) + + solver_result = get_solver_pred(TASK_DESCRIPTION, default_controller) + solver_result_variant = get_solver_pred(TASK_DESCRIPTION_VARIANT, variant_controller) + + metrics = { + "move": next_filtered_moves, + "predicted_move": solver_result.output.strip() in next_filtered_moves, + "predicted_move_in_variant": solver_result_variant.output.strip() + in next_filtered_moves, + "num_previous_moves": len(previous_moves), + "previous_moves": previous_moves, + } + + # Add violations to metrics + metrics.update( + self.get_violations( + default_controller, solver_result.output, previous_moves, "standard" + ) + ) + metrics.update( + self.get_violations( + variant_controller, solver_result_variant.output, previous_moves, "variant" + ) + ) + + evals.record.record_metrics(**metrics) + + def run(self, recorder: RecorderBase) -> dict[str, Union[float, int]]: + if self.diagonal_variation: + self.samples_jsonl = get_diagonal_dataset_path( + registry_path=self._prefix_registry_path("") + ) + else: + self.samples_jsonl = get_dataset_path( + solver=self._solver, + registry_path=self._prefix_registry_path(""), + remake_dataset_if_not_found=self.remake_dataset_if_not_found, + default_model_dataset=self.default_model_dataset, + ) + samples = self.get_samples() + samples = self.rng.sample(samples, min(self.n_samples, len(samples))) + + self.eval_all_samples(recorder, samples) + metrics = recorder.get_metrics() + + predicted_move_proportion = get_binary_avg(metrics, "predicted_move") + predicted_move_in_variant_proportion = get_binary_avg(metrics, "predicted_move_in_variant") + + avg_num_previous_moves = 
sum([i["num_previous_moves"] for i in metrics]) / len(metrics) + std_num_previous_moves = np.std([i["num_previous_moves"] for i in metrics]) + + delta = predicted_move_in_variant_proportion - predicted_move_proportion + variant_impact_factor = (delta / predicted_move_proportion) if predicted_move_proportion != 0 else 0 + + results = { + "variant_impact_factor": variant_impact_factor, + "delta": delta, + "predicted_move_proportion": predicted_move_proportion, + "predicted_move_in_variant_proportion": predicted_move_in_variant_proportion, + "avg_num_previous_moves": avg_num_previous_moves, + "std_num_previous_moves": std_num_previous_moves, + } + + # Add violations + violation_keys = [i for i in metrics[0].keys() if "violation" in i] + violation_results = { + f"{name}_rate": get_binary_avg(metrics, name) for name in violation_keys + } + results.update(violation_results) + + return results + + def get_violations( + self, + controller: BoardController, + solver_output: str, + previous_moves: Sequence[str], + variant_name: str, + ) -> dict: + solver_color = "W" if len(previous_moves) % 2 == 0 else "B" + + piece_moved_outside_board = False + moving_invalid_piece = False + piece_capturing_same_color = False + + violation_metrics = {} + try: + move = controller.notation_parser._str_to_move( + solver_output, controller.board.board_state + ) + + piece_moved_outside_board = not move_within_board(move) + moving_invalid_piece = not same_color_piece_at_move_start( + controller.board.board_state, move, solver_color + ) + piece_capturing_same_color = capturing_same_color(controller.board.board_state, move) + incorrect_notation = False + except (ValueError, KeyError): + incorrect_notation = True + + violation = ( + piece_moved_outside_board + or moving_invalid_piece + or piece_capturing_same_color + or incorrect_notation + ) + violation_metrics = { + f"{variant_name}_violation": violation, + f"{variant_name}_violation_moved_outside_board": piece_moved_outside_board, + 
f"{variant_name}_violation_moving_invalid_piece": moving_invalid_piece, + f"{variant_name}_violation_capturing_same_color": piece_capturing_same_color, + f"{variant_name}_violation_incorrect_notation": incorrect_notation, + } + return violation_metrics diff --git a/evals/elsuite/cant_do_that_anymore/scripts/dataset_creation.py b/evals/elsuite/cant_do_that_anymore/scripts/dataset_creation.py new file mode 100644 index 0000000000..e0c7a0265a --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/scripts/dataset_creation.py @@ -0,0 +1,312 @@ +import argparse +import copy +import os +import pathlib +from typing import Sequence + +import chess.pgn +import requests +import zstandard +from tqdm import tqdm + +from evals.elsuite.cant_do_that_anymore.chess.board import BoardController +from evals.elsuite.cant_do_that_anymore.chess.utils import Move, parse_piece +from evals.elsuite.cant_do_that_anymore.utils import ( + assert_boards_consistent, + dump_sequence_to_jsonl, + initialise_boards, +) + + +def prepare_lichess_2014_dataset(out_dir: str) -> str: + """ + Downloads and extracts Lichess 2014 April dataset, returns the + path to the extracted .pgn file + """ + fname = "lichess_db_standard_rated_2014-04.pgn.zst" + raw_data_out_path = os.path.join(out_dir, fname) + if not os.path.exists(raw_data_out_path): + url = "https://database.lichess.org/standard/" + fname + r = requests.get(url) + open(raw_data_out_path, "wb").write(r.content) + + out_path = os.path.join(out_dir, "pgn_data.pgn") + if not os.path.exists(out_path): + input_file = pathlib.Path(raw_data_out_path) + with open(input_file, "rb") as compressed: + decomp = zstandard.ZstdDecompressor() + with open(out_path, "wb") as destination: + decomp.copy_stream(compressed, destination) + + return out_path + + +class MoveFilter: + def __call__( + self, + default_controller: BoardController, + variant_controller: BoardController, + move: chess.Move, + player_id: str, + ) -> bool: + raise NotImplementedError() + + +class 
SpecialMoveFilter(MoveFilter): + """ + Filters for moves that are: + 1) Legal under the normal rules of chess + 2) Illegal under the variant's rules (i.e. bishop is moved) + """ + + def __call__( + self, + default_controller: BoardController, + variant_controller: BoardController, + move: Move, + player_id: str, + ) -> bool: + if not is_move_illegal(default_controller, move, player_id) and is_move_illegal( + variant_controller, move, player_id + ): + return True + + return False + + +class ControlMoveFilter(MoveFilter): + """ + Finds positions where solvers should have (almost) equivalent predictions under + both sets of rules + Filters for moves that are: + 1) Legal under both the normal and variant's rules of chess + 2) Are on a board containing no bishops + 3) Are on a board where no pawns are close to promoting; neither player's + pawns are in their last three rows + 4) Are on a board with at least four non-pawn pieces between both players + """ + + def __call__( + self, + default_controller: BoardController, + variant_controller: BoardController, + move: Move, + player_id: str, + ) -> bool: + if is_move_illegal(default_controller, move, player_id): + return False + if is_move_illegal(variant_controller, move, player_id): + return False + + board_state = default_controller.board.board_state + num_pieces = 0 + for row_idx in range(8): + for col_idx in range(8): + _, piece_id = parse_piece(board_state, row_idx, col_idx) + if piece_id == 2: + return False + elif piece_id == 0: + if player_id == "W" and row_idx <= 2: + return False + elif player_id == "B" and row_idx >= 5: + return False + elif piece_id != -1: + num_pieces += 1 + + if num_pieces < 4: + return False + + return True + + +def is_move_illegal(controller: BoardController, move: chess.Move, player_id: str) -> bool: + legal_moves = controller.get_player_legal_moves(player_id) + if move in legal_moves: + return False + return True + + +def find_specific_moves_in_game( + game: chess.pgn.Game, + game_idx: int, + 
move_filter: MoveFilter, + default_controller: BoardController, + variant_controller: BoardController, + their_controller: chess.Board, + filter_if_found_previous: bool, +) -> Sequence[dict]: + """ + Given a game, finds all moves that satisfy the given filter + If filter_if_found_previous is True, only finds the first move in the game that + satisfies the filter + """ + player_id = "W" + previous_moves = [] + filtered_moves = [] + for move in game.mainline_moves(): + move = move.uci() + + if move_filter(default_controller, variant_controller, move, player_id): + filtered_moves.append( + { + "game_idx": game_idx, + "previous_moves": copy.deepcopy(previous_moves), + "next_filtered_moves": [move], + "any_previous_move_found": len(filtered_moves) > 0, + } + ) + if filter_if_found_previous: + break + + # Ensure my implementation is correct + assert_boards_consistent(default_controller, their_controller, player_id) + + # Update boards + default_controller.update_board(move) + their_controller.push_san(move) + + variant_controller.board.board_state = default_controller.board.board_state + variant_controller.previous_moves = default_controller.previous_moves + + player_id = "B" if player_id == "W" else "W" + previous_moves.append(move) + + return filtered_moves + + +def create_dataset_of_specific_moves( + pgn_path: str, + move_filter: MoveFilter, + target_num_examples: int, + filter_if_found_previous: bool, + filter_for_unique_previous_moves: bool, + continuously_save: bool, + out_path: str, +): + """ + Iterates over games in dataset and filters moves according to the given move_filter + If filter_for_unique_previous_moves is True, filter to only include moves that have + unique sets of previous moves + If continuously_save is True, saves dataset every time it is updated + """ + pgn = open(pgn_path) + dataset = [] + unique_previous_moves = set() + + t_bar = tqdm(total=target_num_examples) + game_idx = 0 + while True: + game = chess.pgn.read_game(pgn) + if game is None: + break + + 
default_controller, variant_controller, their_controller = initialise_boards() + filtered_moves = find_specific_moves_in_game( + game, + game_idx, + move_filter, + default_controller, + variant_controller, + their_controller, + filter_if_found_previous, + ) + + if filter_for_unique_previous_moves: + for example in filtered_moves: + previous_moves = example["previous_moves"] + if set(previous_moves) not in unique_previous_moves: + dataset.append(example) + unique_previous_moves.add(frozenset(previous_moves)) + t_bar.update(1) + if continuously_save: + dump_sequence_to_jsonl(dataset, out_path) + + elif len(filtered_moves) > 0: + dataset += filtered_moves + t_bar.update(len(filtered_moves)) + if continuously_save: + dump_sequence_to_jsonl(dataset, out_path) + + game_idx += 1 + t_bar.set_description(f"Num games examined: {game_idx}") + + if len(dataset) >= target_num_examples: + break + + return dataset + + +def main(args: argparse.Namespace): + lichess_path = prepare_lichess_2014_dataset(args.out_dir) + + if args.make_special_moves: + move_filter = SpecialMoveFilter() + dataset_name = "special_moves_dataset.jsonl" + out_path = os.path.join(args.out_dir, dataset_name) + dataset = create_dataset_of_specific_moves( + lichess_path, + move_filter, + target_num_examples=args.n_moves, + filter_if_found_previous=args.filter_if_found_previous, + filter_for_unique_previous_moves=args.filter_for_unique_previous_moves, + continuously_save=args.continuously_save, + out_path=out_path, + ) + dump_sequence_to_jsonl(dataset, out_path) + + if args.make_control_moves: + move_filter = ControlMoveFilter() + dataset_name = "control_moves_dataset.jsonl" + out_path = os.path.join(args.out_dir, dataset_name) + dataset = create_dataset_of_specific_moves( + lichess_path, + move_filter, + target_num_examples=args.n_moves, + filter_if_found_previous=args.filter_if_found_previous, + filter_for_unique_previous_moves=args.filter_for_unique_previous_moves, + continuously_save=args.continuously_save, 
+ out_path=out_path, + ) + dump_sequence_to_jsonl(dataset, out_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + + parser.add_argument("--n_moves", type=int, default=5000) + parser.add_argument( + "--out_dir", type=str, default="./evals/registry/data/cant_do_that_anymore/" + ) + parser.add_argument( + "--make_special_moves", + action="store_true", + help="Whether to search and build a dataset of special moves", + default=False, + ) + parser.add_argument( + "--make_control_moves", + action="store_true", + help="Whether to search and build a dataset of control moves", + default=False, + ) + parser.add_argument( + "--filter_if_found_previous", + action="store_true", + help="Whether to filter out moves that have had previous moves that satisfy the filtering condition.", + default=False, + ) + parser.add_argument( + "--filter_for_unique_previous_moves", + action="store_true", + help="Whether to only search for moves with unique previous moves (i.e. a unique set of moves leading up to that position)", + default=False, + ) + parser.add_argument( + "--continuously_save", + action="store_true", + help="Whether to save the dataset every time a new example has been found", + default=False, + ) + args = parser.parse_args() + + main(args) diff --git a/evals/elsuite/cant_do_that_anymore/scripts/diagonal_dataset_creation.py b/evals/elsuite/cant_do_that_anymore/scripts/diagonal_dataset_creation.py new file mode 100644 index 0000000000..491acf3c95 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/scripts/diagonal_dataset_creation.py @@ -0,0 +1,316 @@ +import argparse +import copy +import os +import random +from typing import Optional, Sequence + +from stockfish import Stockfish +from tqdm import tqdm + +from evals.elsuite.cant_do_that_anymore.chess.board import BoardController +from evals.elsuite.cant_do_that_anymore.chess.move_variants import DIAGONAL_MOVES +from evals.elsuite.cant_do_that_anymore.chess.utils import ( + Move, + coord_within_board, + 
move_crosses_pieces, + parse_piece, +) +from evals.elsuite.cant_do_that_anymore.utils import dump_sequence_to_jsonl, initialise_boards + +# NOTE change threads, hash depending on hardware +# https://pypi.org/project/stockfish/ +STOCKFIAH_MOVES_CONSIDERED = 5 +STOCKFISH_DEPTH = 18 +STOCKFISH_PARAMS = { + "Debug Log File": "", + "Contempt": 0, + "Min Split Depth": 0, + "Threads": 8, + "Ponder": "false", + "Hash": 4096, + "MultiPV": 1, + "Skill Level": 10, + "Move Overhead": 10, + "Minimum Thinking Time": 20, + "Slow Mover": 100, + "UCI_Chess960": "true", + "UCI_LimitStrength": "false", + "UCI_Elo": 1500, +} + + +def get_stockfish_move(stockfish: Stockfish, num_moves_to_consider: int) -> str: + """ + Gets the next move predicted by stockfish. Gets top n predictions and + selects randomly weighted by each move's centipawn value + Filters out bishop promotions, since our variant shouldn't have bishops + """ + # Get top moves, filter out bad ones + top_moves = stockfish.get_top_moves(num_moves_to_consider) + + # Filter out bishop promotions + top_moves = [i for i in top_moves if not i["Move"].endswith("b")] + + # If stockfish considers moves that it knows will lead to mate, only + # select from these moves + mates = [i for i in top_moves if i["Mate"] is not None] + if len(mates) > 0: + top_moves = mates + + # Ensures centipawn value isn't None + if all([i["Centipawn"] is None for i in top_moves]): + for move in top_moves: + move["Centipawn"] = 1 + else: + top_moves = [i for i in top_moves if i["Centipawn"] is not None] + + # Makes all centipawns positive + min_centipawn_value = min([i["Centipawn"] for i in top_moves]) + for move in top_moves: + move["Centipawn"] += abs(min_centipawn_value) + + # Normalise centipawn to a probability distribution + centipawn_sum = sum([i["Centipawn"] for i in top_moves]) + for move in top_moves: + move["prob"] = move["Centipawn"] / centipawn_sum + + # Pick move randomly + prob = random.uniform(0, 1) + selected_move = None + for move in 
top_moves: + prob -= move["prob"] + if prob <= 0: + selected_move = move["Move"] + break + + return selected_move + + +def parse_stockfish_move(controller: BoardController, move: str) -> str: + """ + When stockfish outputs a castling move, the move is from the king's position to the + rook's position, e.g. "e8a8" + In my framework castling is indicated by the start+end position of the king, e.g. "e8c8" + This function converts the stockfish notation to my notation + """ + move = controller.notation_parser._str_to_move(move, controller.board.board_state) + _, piece_id = parse_piece( + controller.board.board_state, move.start_coord[0], move.start_coord[1] + ) + + # If castling move + dy = move.target_coord[1] - move.start_coord[1] + if piece_id == 5: + if dy > 2 or dy < -2: + direction = dy / abs(dy) + if direction == 1: # Kingside castling + move.target_coord = [move.target_coord[0], move.target_coord[1] - 1] + else: # Queenside castling + move.target_coord = [move.target_coord[0], move.target_coord[1] + 2] + + move = controller.notation_parser._move_to_str(move, controller.board.board_state) + return move + + +def get_bishop_diagonal_moves(controller: BoardController, player_id: str) -> Sequence[str]: + """ + Gets all possible diagonal moves that a bishop could make on a board, even if the bishop isn't + allowed to move diagonally under the board's rules + """ + # Find all bishops on board + bishop_coords = [] + board_state = controller.board.board_state + for row_idx in range(8): + for col_idx in range(8): + piece_color, piece_id = parse_piece(board_state, row_idx, col_idx) + if piece_color == player_id and piece_id == 2: + bishop_coords.append([row_idx, col_idx]) + + # Find all possible diagonal movements of each bishop + bishop_diagonal_moves = [] + for row_idx, col_idx in bishop_coords: + for transformation in DIAGONAL_MOVES: + new_coord = [row_idx + transformation[0], col_idx + transformation[1]] + move = Move([row_idx, col_idx], new_coord, promotion=None, 
castling=False) + + # If piece doesn't move + if transformation[0] == 0 and transformation[1] == 0: + continue + # If transformation moves piece outside board + if not coord_within_board(new_coord[0], new_coord[1]): + continue + # If transformation moves onto piece of same color + piece_color, _ = parse_piece(controller.board.board_state, new_coord[0], new_coord[1]) + if piece_color == player_id: + continue + # If move crosses friendly pieces + if move_crosses_pieces(controller.board.board_state, move): + continue + + move = controller.notation_parser._move_to_str(move, controller.board.board_state) + bishop_diagonal_moves.append(move) + + return bishop_diagonal_moves + + +def find_specific_moves_in_game( + game_idx: int, + variant_controller: BoardController, + filter_if_found_previous: bool, + max_moves: int, +) -> Sequence[dict]: + """ + Simulates an individual game, using the variant's rules. Finds all possible + diagonal moves from bishops (even though moving bishops diagonally is + illegal under the variant) + If filter_if_found_previous is True, only finds the first position with possible + bishop moves + """ + stockfish = Stockfish(depth=STOCKFISH_DEPTH, parameters=STOCKFISH_PARAMS) + # HACK to have stockfish play our variant, just swap out the bishops for knights + # then later pretend the knights are bishops + stockfish.set_fen_position("rnnqknnr/pppppppp/8/8/8/8/PPPPPPPP/RNNQKNNR w KQkq - 0 1") + previous_moves = [] + player_id = "W" + + # Get ELO of each player + elos = [1350, 1000] + random.shuffle(elos) + white_elo, black_elo = elos + + bishop_diagonal_moves = [] + for _ in range(max_moves): + if player_id == "W": + stockfish.set_elo_rating(white_elo) + else: + stockfish.set_elo_rating(black_elo) + + # Find all diagonal bishop moves from this position + found_moves = get_bishop_diagonal_moves(variant_controller, player_id) + if len(found_moves) > 0: + bishop_diagonal_moves.append( + { + "game_idx": game_idx, + "previous_moves": 
copy.deepcopy(previous_moves), + "next_filtered_moves": found_moves, + } + ) + if filter_if_found_previous: + break + + move = get_stockfish_move(stockfish, STOCKFIAH_MOVES_CONSIDERED) + stockfish.make_moves_from_current_position([move]) + + # Parse into notation that is compatible with my framework + move = parse_stockfish_move(variant_controller, move) + variant_controller.update_board(move) + + player_id = "B" if player_id == "W" else "W" + previous_moves.append(move) + + # If checkmate or stalemate, end + if len(variant_controller.get_player_legal_moves(player_id)) == 0: + break + + return bishop_diagonal_moves + + +def create_bishop_diagonal_dataset( + target_num_examples: int, + max_moves: int, + filter_if_found_previous: bool, + filter_for_unique_previous_moves: bool, + continuously_save: bool, + out_path: Optional[str], +) -> Sequence[dict]: + """ + Simulates stockfish games and finds possible diagonal moves that could be + made by bishops. + If filter_if_found_previous is True, finds only the first position that satisfies + this criterion in each game + If filter_for_unique_previous_moves is True, filters to ensure each + example has a unique set of previous moves + If continuously_save is True, saves dataset every time it is updated + """ + dataset = [] + unique_previous_moves = set() + + t_bar = tqdm(total=target_num_examples) + game_idx = 0 + while True: + _, variant_controller, _ = initialise_boards() + filtered_moves = find_specific_moves_in_game( + game_idx, + variant_controller, + filter_if_found_previous, + max_moves, + ) + + if filter_for_unique_previous_moves: + for example in filtered_moves: + previous_moves = example["previous_moves"] + if set(previous_moves) not in unique_previous_moves: + dataset.append(example) + unique_previous_moves.add(frozenset(previous_moves)) + t_bar.update(1) + if continuously_save: + dump_sequence_to_jsonl(dataset, out_path) + + elif len(filtered_moves) > 0: + dataset += filtered_moves + t_bar.update(len(filtered_moves)) + if 
continuously_save: + dump_sequence_to_jsonl(dataset, out_path) + + game_idx += 1 + t_bar.set_description(f"Num games examined: {game_idx}") + + if len(dataset) >= target_num_examples: + break + + return dataset + + +def main(args: argparse.Namespace): + dataset_name = "diagonal_moves_dataset.jsonl" + out_path = os.path.join(args.out_dir, dataset_name) + dataset = create_bishop_diagonal_dataset( + target_num_examples=args.n_moves, + max_moves=args.max_moves, + filter_if_found_previous=args.filter_if_found_previous, + filter_for_unique_previous_moves=args.filter_for_unique_previous_moves, + continuously_save=args.continuously_save, + out_path=out_path, + ) + dump_sequence_to_jsonl(dataset, out_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + + parser.add_argument("--n_moves", type=int, default=5000) + parser.add_argument("--max_moves", type=int, default=50) + parser.add_argument( + "--out_dir", type=str, default="./evals/registry/data/cant_do_that_anymore/" + ) + parser.add_argument( + "--filter_if_found_previous", + action="store_true", + help="Whether to filter out moves that have had previous moves that satisfy the filtering condition", + default=False, + ) + parser.add_argument( + "--filter_for_unique_previous_moves", + action="store_true", + help="Whether to only search for moves with unique previous moves (i.e. a unique set of moves leading up to that position)", + default=False, + ) + parser.add_argument( + "--continuously_save", + action="store_true", + help="Whether to save the dataset every time a new example has been found", + default=False, + ) + args = parser.parse_args() + + main(args) diff --git a/evals/elsuite/cant_do_that_anymore/scripts/make_plots.py b/evals/elsuite/cant_do_that_anymore/scripts/make_plots.py new file mode 100644 index 0000000000..bd0ea4d5cc --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/scripts/make_plots.py @@ -0,0 +1,128 @@ +import argparse +import os +from pathlib import Path +from typing import 
Sequence + +import pandas as pd +from matplotlib import pyplot as plt + +from evals.elsuite.cant_do_that_anymore.chess.utils import parse_piece +from evals.elsuite.cant_do_that_anymore.utils import initialise_boards +from evals.utils.log_utils import ( + extract_individual_results, + extract_spec, + get_final_results_from_dir, +) + + +def extract_results(datadir: Path) -> pd.DataFrame: + df_agg = [] # Aggregated results + df_samples = [] # Per sample results + for path, results in sorted(list(get_final_results_from_dir(datadir).items())): + spec = extract_spec(path) + solver_path = Path(spec["completion_fns"][0]) + model = solver_path.name + solver = solver_path.parent.name + # Remove root section of path, which is the eval name + solver_path = solver_path.relative_to(solver_path.parts[0]) + # Aggregated + df_agg.append( + { + "solver_path": str(solver_path), + "model": str(model), + "solver": str(solver), + **spec["run_config"]["eval_spec"]["args"], + **results, + } + ) + # Per-sample + for res in extract_individual_results(path): + df_samples.append( + { + "solver_path": str(solver_path), + "model": str(model), + "solver": str(solver), + **spec["run_config"]["eval_spec"]["args"], + **res, + } + ) + df_agg = pd.DataFrame(df_agg) + df_samples = pd.DataFrame(df_samples) + return df_agg, df_samples + + +def render_results(df: pd.DataFrame, out_dir: Path): + agg_operations = { + "predicted_move_proportion": ["mean", "sem"], + "predicted_move_in_variant_proportion": ["mean", "sem"], + } + df = df.groupby("solver_path").agg(agg_operations).reset_index() + df = df.round(2) + print(df.to_csv(index=False)) + df.to_csv(os.path.join(out_dir, "results.csv"), index=False) + + +def compute_num_previous_bishop_moves(previous_moves: Sequence[str]) -> int: + controller, _, _ = initialise_boards() + + num_previous_bishop_moves = 0 + for move in previous_moves: + start_coord = controller.notation_parser._str_to_move( + move, controller.board.board_state + ).start_coord + _, piece_id 
= parse_piece(controller.board.board_state, start_coord[0], start_coord[1]) + if piece_id == 2: + num_previous_bishop_moves += 1 + + controller.update_board(move) + + return num_previous_bishop_moves + + +def plot_diagonal_bishop_results(df: pd.DataFrame, out_dir: Path): + # Get number of previous bishop moves + df["num_previous_bishop_moves"] = [ + compute_num_previous_bishop_moves(i) for i in df["previous_moves"] + ] + + # Calculate headline metrics per solver, and number of previous moves + agg_operations = { + "predicted_move_in_variant": ["mean"], + } + df = df.groupby(["solver_path", "num_previous_bishop_moves"]).agg(agg_operations).reset_index() + + # Plot separately for each solver + for model, group in df.groupby("solver_path"): + plt.plot( + group["num_previous_bishop_moves"], + group["predicted_move_in_variant"], + label=model, + ) + + plt.xlabel("Num previous bishop moves") + plt.ylabel("Proportion of (illegal) predicted diagonal bishop moves") + plt.ylim([0, 1]) + plt.legend() + plt.savefig(os.path.join(out_dir, "diagonal.png")) + plt.show() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--log_dir", "-d", type=str, required=True) + parser.add_argument("--out_dir", "-o", type=str, required=True) + parser.add_argument("--diagonal_variant", action="store_true", default=False) + args = parser.parse_args() + log_dir = Path(args.log_dir) + out_dir = Path(args.out_dir) + out_dir.mkdir(exist_ok=True, parents=True) + + df_agg, df_samples = extract_results(log_dir) + render_results(df_agg, out_dir) + + if args.diagonal_variant: + plot_diagonal_bishop_results(df_samples, out_dir) + + +if __name__ == "__main__": + main() diff --git a/evals/elsuite/cant_do_that_anymore/scripts/run_experiments.sh b/evals/elsuite/cant_do_that_anymore/scripts/run_experiments.sh new file mode 100755 index 0000000000..68fe4ac5e7 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/scripts/run_experiments.sh @@ -0,0 +1,67 @@ +#!/bin/bash +logdir=./logs 
+outputdir=./outputs + +timestamp=$(date +%Y%m%d_%H%M%S) +logpathbase=$logdir/$timestamp/ + +mkdir -p ${logpathbase} + +declare -a SOLVERS_ZEROSHOT=( + "generation/direct/gpt-3.5-turbo" + "chess/generation/direct/gpt-3.5-turbo-instruct" + "generation/direct/gpt-4-turbo-preview" + "chess/generation/direct/gpt-4-base" +) + +# See if variant was indicated +run_diagonal_variant=1 +for arg in "$@" +do + if [[ $arg == "--no_diagonal_variant" ]]; then + run_diagonal_variant=0 + break + fi +done + +# TODO CoT solvers + +echo Running experiments and logging to $logpathbase + +for run_idx in {0..2} +do + for solver in "${SOLVERS_ZEROSHOT[@]}" + do + log_name=${solver//\//-} + oaieval $solver cant_do_that_anymore \ + --record_path ${logpathbase}run_${run_idx}_${log_name}.log \ + --extra_eval_params n_samples=1000 \ + --seed ${run_idx} + done +done + +echo Done running experiments, all logs in $logpathbase + +echo Producing plots, outputs to $outputdir +python make_plots.py --log_dir $logpathbase --out_dir $outputdir + +if [[ $run_diagonal_variant -eq 1 ]]; then + echo Running diagonal experiment and logging to $logpathbase + + for run_idx in {0..2} + do + for solver in "${SOLVERS_ZEROSHOT[@]}" + do + log_name=${solver//\//-} + oaieval $solver cant_do_that_anymore.all_diagonal \ + --record_path ${logpathbase}run_${run_idx}_${log_name}.log \ + --extra_eval_params n_samples=1000 \ + --seed ${run_idx} + done + done + + echo Done running experiments, all logs in $logpathbase + + echo Producing plots, outputs to $outputdir + python make_plots.py --log_dir $logpathbase --out_dir $outputdir --diagonal_variant +fi \ No newline at end of file diff --git a/evals/elsuite/cant_do_that_anymore/utils.py b/evals/elsuite/cant_do_that_anymore/utils.py new file mode 100644 index 0000000000..519aad8596 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/utils.py @@ -0,0 +1,250 @@ +import json +import logging +import os +from multiprocessing.pool import ThreadPool +from typing import Sequence 
+ +import chess +from tqdm import tqdm + +from evals.elsuite.cant_do_that_anymore.chess.board import BoardController +from evals.elsuite.cant_do_that_anymore.chess.board_test import default_board_init +from evals.elsuite.cant_do_that_anymore.chess.move_variants import ( + PIECE_ID_TO_INSTANCE, + PIECE_ID_TO_STR, + PIECE_STR_TO_ID, + VARIANT_PIECE_ID_TO_INSTANCE, +) +from evals.elsuite.cant_do_that_anymore.chess.notation import AlgebraicNotationParser +from evals.elsuite.cant_do_that_anymore.defaults import TASK_DESCRIPTION +from evals.record import DummyRecorder, RecorderBase +from evals.solvers.solver import DummySolver, Solver +from evals.task_state import Message, TaskState + +logger = logging.getLogger(__name__) + + +def construct_messages(previous_moves: Sequence[str]) -> Sequence[Message]: + """ + Creates list of Message's containing the previous chess moves. The last + Message is always from the "user" + """ + solver_is_white = len(previous_moves) % 2 == 0 + messages = [] + current_player = "assistant" if solver_is_white else "user" + for move in previous_moves: + messages.append(Message(current_player, move)) + # toggle current player + current_player = "assistant" if current_player == "user" else "user" + + return messages + + +def dump_sequence_to_jsonl(data: Sequence[dict], path: str): + with open(path, "w+") as f: + for example in data: + example = json.dumps(example) + f.write(f"{example}\n") + + +def load_sequence_from_jsonl(path: str) -> Sequence[dict]: + data = [] + with open(path, "r") as f: + for line in f: + line = json.loads(line) + data.append(line) + + return data + + +def initialise_boards() -> tuple[BoardController, BoardController, chess.Board]: + """ + Initialises local chess framework, and framework from + python-chess library + """ + default_controller = BoardController( + default_board_init, + PIECE_ID_TO_INSTANCE, + PIECE_STR_TO_ID, + PIECE_ID_TO_STR, + AlgebraicNotationParser(PIECE_STR_TO_ID, PIECE_ID_TO_STR), + ) + variant_controller 
= BoardController( + default_board_init, + VARIANT_PIECE_ID_TO_INSTANCE, + PIECE_STR_TO_ID, + PIECE_ID_TO_STR, + AlgebraicNotationParser(PIECE_STR_TO_ID, PIECE_ID_TO_STR), + ) + their_controller = chess.Board() + + return default_controller, variant_controller, their_controller + + +def assert_boards_consistent( + controller: BoardController, their_controller: chess.Board, player_id: str +): + """ + Checks both boards have consistent states by ensuring both have the same set of legal moves + """ + our_legal_moves = sorted(controller.get_player_legal_moves(player_id)) + their_legal_moves = sorted([str(i) for i in their_controller.legal_moves]) + if our_legal_moves != their_legal_moves: + our_additional_moves = list(set(our_legal_moves) - set(their_legal_moves)) + their_additional_moves = list(set(their_legal_moves) - set(our_legal_moves)) + assert False, f""" + Inconsistent legal moves between the boards! + Our legal moves: {our_legal_moves}, + Their legal moves: {their_legal_moves}, + Moves we had they didn't: {our_additional_moves}, + Moves they had we didn't: {their_additional_moves}, + Board state:\n{controller.board.board_state} + """ + + +def does_solver_predict_move( + solver: Solver, + recorder: RecorderBase, + task_description: str, + special_move: str, + previous_moves: Sequence[str], +): + task_state = TaskState( + task_description, + construct_messages(previous_moves), + ) + + with recorder.as_default_recorder(-1): + solver_result = solver(task_state, **{"max_tokens": 4}) + pred_str = solver_result.output.strip() + + if pred_str == special_move: + return True + + return False + + +def process_example(work_input: dict): + solver, recorder, example, task_description = ( + work_input["solver"], + work_input["recorder"], + work_input["example"], + work_input["task_description"], + ) + special_move, previous_moves = example["special_move"], example["previous_moves"] + + predicts_move = does_solver_predict_move( + solver, + recorder, + task_description, + 
special_move, + previous_moves, + ) + return predicts_move, example + + +def get_solver_predictions( + solver: Solver, + recorder: RecorderBase, + special_moves_dataset: Sequence[dict], + n_threads: int, + task_description: str, +) -> Sequence[dict]: + """ + Filter to find all special moves that the solver would have predicted under the normal + rules of chess with temp=0, then dump this dataset + """ + solver_moves_dataset = [] + work_items = [ + { + "solver": solver, + "recorder": recorder, + "example": example, + "task_description": task_description, + } + for example in special_moves_dataset + ] + + t_bar = tqdm(total=len(special_moves_dataset)) + with ThreadPool(n_threads) as pool: + iter = pool.imap_unordered(process_example, work_items) + + for result in (t_bar := tqdm(iter, total=len(work_items))): + predicts_move, example = result + if predicts_move: + solver_moves_dataset.append(example) + t_bar.set_description(f"Dataset size: {len(solver_moves_dataset)}") + + return solver_moves_dataset + + +def get_dataset_path( + solver: Solver, + registry_path: str, + remake_dataset_if_not_found: bool, + default_model_dataset: str, +) -> str: + """ + This dataset requires each evaluated model to have its own dataset. 
We get the exact + model being evaluated, check if a dataset exists for it, if not we generate one + """ + recorder = DummyRecorder(None) + with recorder.as_default_recorder("x"): + solver_version = solver.model_version + + # If nested solver, convert returned dict to str + if isinstance(solver_version, dict): + solver_version = json.dumps(solver_version) + + all_datasets_path = os.path.join(registry_path, "cant_do_that_anymore") + + # Check if dataset exists + solver_dataset_path = os.path.join(all_datasets_path, f"{solver_version}_dataset.jsonl") + if os.path.exists(solver_dataset_path): + return solver_dataset_path + + # Remake, or load default + if isinstance(solver, DummySolver): + return f"cant_do_that_anymore/{default_model_dataset}_dataset.jsonl" + elif remake_dataset_if_not_found: + logger.warning( + f"Generating dataset for {solver_version}! Ideally the solver should be using temperature=0 when creating the dataset, " + "otherwise generated dataset will be of a slightly different distribution" + ) + create_dataset(solver, recorder, solver_dataset_path, all_datasets_path) + return solver_dataset_path + else: + logger.warning( + f"Dataset for {solver_version} wasn't found! Using the dataset for {default_model_dataset} instead." 
+ ) + return f"cant_do_that_anymore/{default_model_dataset}_dataset.jsonl" + + +def create_dataset( + solver: Solver, recorder: RecorderBase, solver_dataset_path: str, all_datasets_path: str +): + threads = int(os.environ.get("EVALS_THREADS", "10")) + + special_moves_dataset = load_sequence_from_jsonl( + os.path.join(all_datasets_path, "special_moves_dataset.jsonl") + ) + solver_moves_dataset = get_solver_predictions( + solver, + recorder, + special_moves_dataset, + n_threads=threads, + task_description=TASK_DESCRIPTION, + ) + dump_sequence_to_jsonl(solver_moves_dataset, solver_dataset_path) + + +def get_diagonal_dataset_path( + registry_path: str, +) -> str: + return os.path.join(registry_path, "cant_do_that_anymore/diagonal_moves_dataset.jsonl") + + +def get_binary_avg(metrics: dict, key: str) -> float: + positive_examples = [i for i in metrics if i[key]] + avg = len(positive_examples) / len(metrics) + return avg diff --git a/evals/elsuite/error_recovery/README.md b/evals/elsuite/error_recovery/README.md new file mode 100644 index 0000000000..78f6ceb726 --- /dev/null +++ b/evals/elsuite/error_recovery/README.md @@ -0,0 +1,96 @@ +# Error Recovery + +This eval evaluates how well LMs can recover from having made reasoning errors while solving a problem. + +## Usage + +Run with: + +``` +oaieval error-recovery +``` + +We suggest using `error_recovery/gpt-4-0613` or `error_recovery/gpt-3.5-turbo-0613` as default choices for `` + +For more examples of running this eval, see `scripts/run_experiments.sh` + +## Dataset + +The dataset consists of 1436 problems from [BIG-Bench Mistake](https://github.com/WHGTyen/BIG-Bench-Mistake) which have reasoning traces with mistakes in them, and incorrect answers. + +## Evaluation Process + +For each example, we run the solver on three variations of the problem: + +1) No Reasoning (NR) variation - Give the solver the question and ask it to generate reasoning and an answer. 
Example below: +``` +[SYSTEM] +Solve the given problem, writing your reasoning along the way. + +[USER] +Solve this problem for x: x = ((3 + 3) * 5) + 2 +``` + + +2) Correct Reasoning (CR) variation - Give the solver the question and some correct reasoning steps and have it keep generating reasoning and an answer. Example below: + +``` +[SYSTEM] +Solve the given problem, writing your reasoning along the way. + +[USER] +Solve this problem for x: x = ((3 + 3) * 5) + 2 + +[ASSISTANT] +3 + 3 = 6 +``` + +3) Incorrect Reasoning (IR) variation - Give the solver the question, the same correct reasoning steps as before, and an incorrect reasoning step and have it keep generating reasoning and an answer. + +``` +[SYSTEM] +Solve the given problem, writing your reasoning along the way. + +[USER] +Solve this problem for x: x = ((3 + 3) * 5) + 2 + +[ASSISTANT] +3 + 3 = 6 +6 * 5 = 35 +``` + + +## Metrics + +The key metrics of this eval are: + +| Metric | Interpretation | +| --- | --- | +| `NR_correct_rate` | The accuracy of the solver in the NR variation. | +| `CR_correct_rate` | The accuracy of the solver in the CR variation. | +| `IR_correct_rate` | The accuracy of the solver in the IR variation. | + +## Variants + +| Variant | Notes | +| --- | --- | +| Default: `error-recovery.main` | Default setting. Includes all 1436 examples. | +| `error-recovery.medium` | A balanced subset of 250 examples, 50 from each task. | +| `error-recovery.small` | A balanced subset of 10 examples, 2 from each task (primarily for quick tests). 
| + +## Token Usage Estimates + +Below is a rough estimate of the total number of tokens consumed by the default variant: + +| Solver | Tokens | +| --- | --- | +| error_recovery/gpt-3.5-turbo-0613 | 4,750,781 | +| error_recovery/gpt-4-0613 | 4,386,440 | + +## Version History + +- v0: Initial version released + +## Contribution statement + +Eval design, implementation, results evaluation, and report were primarily done by Ian McKenzie based on an initial version by Oliver Jaffe. Work done under the guidance of (alphabetically by last-name) Steven Adler, James Aung, and Chan Jun Shern, who provided research input, report revisions, and project management support. diff --git a/evals/elsuite/error_recovery/defaults.py b/evals/elsuite/error_recovery/defaults.py new file mode 100644 index 0000000000..ee5560a08d --- /dev/null +++ b/evals/elsuite/error_recovery/defaults.py @@ -0,0 +1,15 @@ +DEFAULT_TASK_DESCRIPTION = "Solve the given problem, writing your reasoning along the way." + +DEFAULT_MISTAKE_MESSAGE = "There might be a mistake in your reasoning." + +DEFAULT_FINAL_ANSWER_MESSAGE = ( + "Given this reasoning, write your final answer. Only write your final answer, and nothing else." +) + +TASK_SPECIFIC_EXTRACTION_INFO = { + "dyck_languages": "\n\nAnswer with just the end of the sequence, separated by spaces. Do not repeat the part of the sequence given in the question. Only write the sequence of symbols, nothing else.", + "logical_deduction": "\n\nAnswer with the selected single letter indicating your answer, wrapped with parentheses. Do not write anything else.", + "multistep_arithmetic": "\n\nAnswer with a single number.", + "tracking_shuffled_objects": "\n\nAnswer with the selected single letter indicating your answer, wrapped with parentheses. 
Do not write anything else.", + "word_sorting": "\n\nAnswer with the sorted words, each lower case and separated by a single space.", +} diff --git a/evals/elsuite/error_recovery/eval.py b/evals/elsuite/error_recovery/eval.py new file mode 100644 index 0000000000..89512179fe --- /dev/null +++ b/evals/elsuite/error_recovery/eval.py @@ -0,0 +1,284 @@ +import copy +import random +from dataclasses import dataclass +from typing import Any, List, Literal, Optional, Sequence + +import evals +import evals.metrics +import evals.record +from evals.api import CompletionFn +from evals.elsuite.error_recovery.defaults import ( + DEFAULT_FINAL_ANSWER_MESSAGE, + DEFAULT_MISTAKE_MESSAGE, + DEFAULT_TASK_DESCRIPTION, + TASK_SPECIFIC_EXTRACTION_INFO, +) +from evals.eval import SolverEval +from evals.solvers.solver import Solver +from evals.task_state import Message, TaskState + +# possible Mistake NOTIFication POSitions +MistakeNotifPos = Literal["immediate", "end"] + + +@dataclass +class Sample: + question: str + correct_steps: Sequence[str] + incorrect_step: str + target: Any + task: str + num_ground_truth_steps: int + mistake_index: int + + +class ErrorRecovery(SolverEval): + def __init__( + self, + completion_fns: Sequence[CompletionFn], + samples_jsonl: str, + n_samples: Optional[int] = None, + mistake_notification_position: Optional[MistakeNotifPos] = None, + mistake_notification_for_ir_only: bool = False, + mark_as_own_reasoning: bool = True, + final_answer_prompt_role: str = "system", + *args, + **kwargs, + ): + """Evaluate a solver on the error recovery task. + + Args: + completion_fns: The completion functions to evaluate. (should be a single solver) + samples_jsonl: The relative path to the samples jsonl file in evals/registry/data. + n_samples: The number of samples to use. If None, use all samples. + mistake_notification_position: The position of the mistake + notification. 
Options are "immediate" for right after the provided + reasoning, or "end" for right after the model-generated reasoning. + If None, no mistake notification is added. + mistake_notification_for_ir_only: Whether to only add the mistake notification + for the incorrect reasoning case. If True, the mistake notification is + added for the incorrect reasoning case, and not for the correct reasoning + or no reasoning cases. + mark_as_own_reasoning: Whether to include the sample reasoning as an + 'assistant' or 'user' message. + final_answer_prompt_role: The role to use for the final answer prompt. Should + be either "system" or "user". + """ + super().__init__( + completion_fns=completion_fns, samples_jsonl=samples_jsonl, *args, **kwargs + ) + + self.n_samples = n_samples + self.mistake_notif_pos: Optional[MistakeNotifPos] = mistake_notification_position + self.mistake_notif_ir_only = mistake_notification_for_ir_only + + # there are some issues with passing bools in from extra_eval_params + assert isinstance(mark_as_own_reasoning, bool) + self.mark_as_own_reasoning = mark_as_own_reasoning + + self.final_answer_prompt_role = final_answer_prompt_role + assert self.final_answer_prompt_role in ["system", "user"] + + def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random, extra_logging=None): + task = sample.task + + # Get the baseline with no provided reasoning + nr_task_state = self._get_no_reasoning_task_state(sample) + # only "end" makes sense for 'no reasoning' + nr_notif_pos = "end" if self.mistake_notif_pos == "end" else None + if self.mistake_notif_ir_only: + nr_notif_pos = None + + nr_answer = self._get_answer( + solver=solver, + task_state=nr_task_state, + sample=sample, + mistake_notif_pos=nr_notif_pos, + ) + + # Run with correct reasoning + cr_task_state = self._get_correct_reasoning_task_state(sample) + cr_notif_pos = self.mistake_notif_pos + if self.mistake_notif_ir_only: + cr_notif_pos = None + + cr_answer = self._get_answer( + solver=solver, 
+ task_state=cr_task_state, + sample=sample, + mistake_notif_pos=cr_notif_pos, + ) + + # Run with incorrect reasoning + ir_task_state = self._get_incorrect_reasoning_task_state(sample) + ir_notif_pos = self.mistake_notif_pos + + ir_answer = self._get_answer( + solver=solver, + task_state=ir_task_state, + sample=sample, + mistake_notif_pos=ir_notif_pos, + ) + + assert len(sample.correct_steps) == sample.mistake_index + + metrics = { + "task": task, + "num_ground_truth_steps": sample.num_ground_truth_steps, + "mistake_index": sample.mistake_index, + "target": str(sample.target), # ground truth answer + "mistake_notification_position": self.mistake_notif_pos, + "mistake_notification_for_ir_only": self.mistake_notif_ir_only, + "NR_sampled": nr_answer, + "CR_sampled": cr_answer, + "IR_sampled": ir_answer, + "NR_correct": nr_answer == str(sample.target), + "CR_correct": cr_answer == str(sample.target), + "IR_correct": ir_answer == str(sample.target), + } + evals.record.record_metrics(**metrics) + + def _get_no_reasoning_task_state(self, sample: Sample) -> TaskState: + task_description = DEFAULT_TASK_DESCRIPTION + no_reasoning_messages = [ + Message(role="user", content=sample.question), + ] + no_reasoning_task_state = TaskState( + task_description=task_description, + messages=no_reasoning_messages, + ) + return no_reasoning_task_state + + def _get_correct_reasoning_task_state(self, sample: Sample) -> TaskState: + task_description = DEFAULT_TASK_DESCRIPTION + correct_steps = "\n".join(sample.correct_steps) + reasoning_role = "assistant" if self.mark_as_own_reasoning else "user" + correct_reasoning_messages = [ + Message(role="user", content=sample.question), + Message(role=reasoning_role, content=correct_steps), + ] + correct_reasoning_task_state = TaskState( + task_description=task_description, + messages=correct_reasoning_messages, + ) + return correct_reasoning_task_state + + def _get_incorrect_reasoning_task_state( + self, + sample: Sample, + ) -> TaskState: + 
task_description = DEFAULT_TASK_DESCRIPTION + correct_steps = "\n".join(sample.correct_steps) + steps_with_incorrect_reasoning = f"{correct_steps}\n{sample.incorrect_step}" + reasoning_role = "assistant" if self.mark_as_own_reasoning else "user" + incorrect_reasoning_messages = [ + Message(role="user", content=sample.question), + Message(role=reasoning_role, content=steps_with_incorrect_reasoning), + ] + + incorrect_reasoning_task_state = TaskState( + task_description=task_description, + messages=incorrect_reasoning_messages, + ) + return incorrect_reasoning_task_state + + def _get_answer( + self, + solver: Solver, + task_state: TaskState, + sample: Sample, + mistake_notif_pos: Optional[MistakeNotifPos], + ) -> str: + """Get a final answer from the solver for a given sample. + + Args: + solver: The solver to use. + task_state: The task state to use. + sample: The Sample being evaluated (relevant for answer extraction). + mistake_notification_position: The position of the mistake notification. + Options are "immediate" for right after the provided reasoning, or "end" for right + after the model-generated reasoning. If None, no mistake notification is added. 
+ + TODO (ian): Work out whether to add mistake notification to 'no reasoning' baseline + """ + mistake_message = Message("user", DEFAULT_MISTAKE_MESSAGE) + if mistake_notif_pos == "immediate": + task_state.messages.append(mistake_message) + + output = solver(task_state=task_state).output + task_state.messages.append(Message("assistant", output)) + + # run solver again if mistake notification is at the end + if mistake_notif_pos == "end": + task_state.messages.append(mistake_message) + output = solver(task_state=task_state).output + task_state.messages.append(Message("assistant", output)) + + answer = self._extract_final_answer(solver=solver, task_state=task_state, sample=sample) + return answer + + def run(self, recorder: evals.record.Recorder): + samples = self.get_samples() + + self.eval_all_samples(recorder, samples) + metrics = recorder.get_metrics() + + NR_correct_rate = len([i for i in metrics if i["NR_correct"]]) / len(metrics) + CR_correct_rate = len([i for i in metrics if i["CR_correct"]]) / len(metrics) + IR_correct_rate = len([i for i in metrics if i["IR_correct"]]) / len(metrics) + + results = { + "NR_correct_rate": NR_correct_rate, + "CR_correct_rate": CR_correct_rate, + "IR_correct_rate": IR_correct_rate, + } + + # Split results per type of task + all_tasks = set([i["task"] for i in metrics]) + for task in all_tasks: + filtered_metrics = [i for i in metrics if i["task"] == task] + NR_correct_rate = len([i for i in filtered_metrics if i["NR_correct"]]) / len( + filtered_metrics + ) + CR_correct_rate = len([i for i in filtered_metrics if i["CR_correct"]]) / len( + filtered_metrics + ) + IR_correct_rate = len([i for i in filtered_metrics if i["IR_correct"]]) / len( + filtered_metrics + ) + + # we use hyphens in the task name so they can be extracted by splitting on underscores + task_string = task.replace("_", "-") + results.update( + { + f"task_{task_string}_NR_correct_rate": NR_correct_rate, + f"task_{task_string}_CR_correct_rate": CR_correct_rate, + 
f"task_{task_string}_IR_correct_rate": IR_correct_rate, + } + ) + + return results + + def _extract_final_answer(self, solver: Solver, task_state: TaskState, sample: Sample): + """Extract the final answer from the solver output using the same solver.""" + task_state = copy.deepcopy(task_state) + + task_specific_info = TASK_SPECIFIC_EXTRACTION_INFO[sample.task] + final_answer_prompt = DEFAULT_FINAL_ANSWER_MESSAGE + task_specific_info + + task_state.messages.append( + Message(role=self.final_answer_prompt_role, content=final_answer_prompt) + ) + answer = solver(task_state=task_state).output + + return answer + + def get_samples(self) -> List[Sample]: + samples = super().get_samples() + + if self.n_samples is not None: + assert ( + len(samples) >= self.n_samples + ), f"Can't get {self.n_samples} samples from a dataset with {len(samples)} samples" + samples = samples[: self.n_samples] + return [Sample(**sample_dict) for sample_dict in samples] diff --git a/evals/elsuite/error_recovery/scripts/dataset_creation.py b/evals/elsuite/error_recovery/scripts/dataset_creation.py new file mode 100644 index 0000000000..c6c14b2417 --- /dev/null +++ b/evals/elsuite/error_recovery/scripts/dataset_creation.py @@ -0,0 +1,156 @@ +import subprocess +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +TASK_PREFIX = { + "dyck_languages": ( + "Given the following sequence of opening and closing brackets, " + "provide the minimal sequence of additional brackets that would " + "balance the original sequence:\n" + ), + "logical_deduction": "", + "multistep_arithmetic": "", + "tracking_shuffled_objects": "", + "word_sorting": "Sort the following list of words alphabetically:\n", +} + + +def main(): + data = clone_and_load_data() + # plot_hist(data) + pos_data = create_positive_examples(data) + # don't use examples where last step is mistake + pos_data = pos_data[pos_data["mistake_index"] < pos_data["num_steps"] - 1] + + # only save a subset of the columns + 
pos_data = pos_data[ + ["input", "correct_steps", "incorrect_step", "mistake_index", "num_steps", "target", "task"] + ] + pos_data.rename( + columns={ + "input": "question", + "num_steps": "num_ground_truth_steps", + }, + inplace=True, + ) + + # save data + save_path = Path("evals/registry/data/error_recovery/main.jsonl") + pos_data.to_json(save_path, orient="records", lines=True) + + small_save_path = Path("evals/registry/data/error_recovery/small.jsonl") + # get small dataset with two examples from each task + small_data = create_data_subset(pos_data, examples_per_task=2) + small_data.to_json(small_save_path, orient="records", lines=True) + + medium_save_path = Path("evals/registry/data/error_recovery/medium.jsonl") + # get medium dataset with 50 examples from each task + medium_data = create_data_subset(pos_data, examples_per_task=50) + medium_data.to_json(medium_save_path, orient="records", lines=True) + + +def create_data_subset(data: pd.DataFrame, examples_per_task: int) -> pd.DataFrame: + # get small dataset with a subset of examples from each task + small_data = pd.DataFrame() + for task in data["task"].unique(): + task_data = data[data["task"] == task] + task_subset = task_data[:examples_per_task] + if len(task_subset) < examples_per_task: + raise ValueError( + f"Task {task} has only {len(task_subset)} examples, less than {examples_per_task}" + ) + small_data = pd.concat((small_data, task_subset)) + return small_data + + +def create_positive_examples(data: pd.DataFrame) -> pd.DataFrame: + has_incorrect_reasoning = ~data["mistake_index"].isnull() + has_incorrect_answer = data["target"] != data["answer"] + positive_condition = has_incorrect_reasoning & has_incorrect_answer + + positive_data = data.copy() + positive_data = positive_data[positive_condition].reset_index() + positive_data["label"] = "positive" + positive_data["correct_steps"] = positive_data.apply( + lambda row: row["steps"][: int(row["mistake_index"])], axis=1 + ) + 
positive_data["incorrect_step"] = positive_data.apply( + lambda row: row["steps"][int(row["mistake_index"])], axis=1 + ) + return positive_data + + +def create_negative_examples(data: pd.DataFrame) -> pd.DataFrame: + """Create a dataset of examples with correct reasoning and answer. + + The 'negative' naming is a bit misleading, but these are the examples + we don't use. + TODO (ian): think about renaming + """ + has_correct_reasoning = data["mistake_index"].isnull() + has_correct_answer = data["target"] == data["answer"] + negative_condition = has_correct_reasoning & has_correct_answer + negative_data = data.copy() + negative_data = negative_data[negative_condition].reset_index() + negative_data["label"] = "negative" + negative_data["correct_steps"] = negative_data["steps"] + negative_data["incorrect_step"] = "" + return negative_data + + +def clone_and_load_data(): + clone_dir = Path("/tmp/BIG-Bench-Mistake") + maybe_clone_repo(clone_dir) + + data = pd.DataFrame() + for jsonl_file in clone_dir.glob("*.jsonl"): + file_data = pd.read_json(jsonl_file, lines=True) + + # Manually append task description to datasets missing one + task = jsonl_file.stem + prefix = TASK_PREFIX[task] + file_data["input"] = prefix + file_data["input"] + file_data["task"] = task + + data = pd.concat((data, file_data)) + + data["num_steps"] = data["steps"].apply(lambda x: len(x)) + return data + + +def maybe_clone_repo(clone_dir): + if not clone_dir.exists(): + subprocess.run( + ["git", "clone", "https://github.com/WHGTyen/BIG-Bench-Mistake.git", str(clone_dir)] + ) + + +def plot_hist(data): + data["num_steps"].hist(bins=max(data["num_steps"])) + plt.show() + + +def print_example(): + data = clone_and_load_data() + # printing some examples + subset_data = create_positive_examples(data) + # subset_data = create_negative_examples(data) + # # print one negative object swapping example + # neg_example = neg_data[neg_data["task"] == "tracking_shuffled_objects"].iloc[0] + # # print one negative 
dyck example + # neg_example = neg_data[neg_data["task"] == "dyck_languages"].iloc[0] + # neg_example = neg_data[neg_data["task"] == "logical_deduction"].iloc[0] + example = subset_data[subset_data["task"] == "multistep_arithmetic"].iloc[1] + print(f"INPUT ======\n{example['input']}") + steps = "\n".join(example["steps"]) + print(f"STEPS ======\n{steps}") + print(f"MISTAKE INDEX ======\n{example['mistake_index']}") + print(f"ANSWER ======\n{example['answer']}") + print(f"TARGET ======\n{example['target']}") + print("========") + + +if __name__ == "__main__": + main() diff --git a/evals/elsuite/error_recovery/scripts/make_plots.py b/evals/elsuite/error_recovery/scripts/make_plots.py new file mode 100644 index 0000000000..0d2dcfaa43 --- /dev/null +++ b/evals/elsuite/error_recovery/scripts/make_plots.py @@ -0,0 +1,597 @@ +import argparse +import os +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt + +from evals.utils import log_utils + +# MODEL_NAMES = { +# "error_recovery/gpt-4-0613": "GPT-4", +# "generation/hhh/gpt-4-base": "GPT-4 Base", +# "error_recovery/gpt-3.5-turbo-0613": "GPT-3.5", +# # "gpt-4-base": "gpt-4-base", +# } +# using model checkpoint names +MODEL_NAMES = { + "error_recovery/gpt-4-0613": "gpt-4-0613", + "generation/hhh/gpt-4-base": "gpt-4-base", + "error_recovery/gpt-3.5-turbo-0613": "gpt-3.5-turbo-0613", + # "generation/direct/llama-2-13b-chat": "llama-2-13b-chat", + "generation/direct/llama-2-70b-chat": "llama-2-70b-chat", + "generation/direct/mixtral-8x7b-instruct": "mixtral-8x7b-instruct", + "generation/direct/gemini-pro": "gemini-pro-1.0", +} + +MODEL_COLOR_MAP = { + "error_recovery/gpt-4-0613": "purple", + "generation/hhh/gpt-4-base": "plum", + "error_recovery/gpt-3.5-turbo-0613": "g", + # "generation/direct/llama-2-13b-chat": "wheat", + "generation/direct/llama-2-70b-chat": "orange", + "generation/direct/mixtral-8x7b-instruct": "red", + 
"generation/direct/gemini-pro": "cornflowerblue", +} +VARIATION_NAMES = { + "nr_name": "From Scratch", + "cr_name": "Correct Basis", + "ir_name": "Incorrect Basis", +} + +VARIATION_COLOR_MAP = { + "nr_name": "blue", + "cr_name": "green", + "ir_name": "red", +} + +TASK_NAMES = { + "word_sorting": "Word Sorting", + "tracking_shuffled_objects": "Tracking Shuffled Objects", + "logical_deduction": "Logical Deduction", + "multistep_arithmetic": "Multi-Step Arithmetic", + "dyck_languages": "Dyck Languages", +} + + +def maybe_show(fig): + if DISPLAY: + fig.show() + plt.close(fig) + + +def extract_results(datadir: Path) -> pd.DataFrame: + df_rows = [] + for path, results in log_utils.get_final_results_from_dir(datadir).items(): + spec = log_utils.extract_spec(path) + model = spec["completion_fns"][0] + base_eval = spec["base_eval"] + df_rows.append( + { + "model": model, + "base_eval": base_eval, + **results, + } + ) + df = pd.DataFrame(df_rows) + return df + + +def extract_metrics(datadir: Path) -> pd.DataFrame: + df_rows = [] + for path, results in sorted(list(log_utils.get_final_results_from_dir(datadir).items())): + spec = log_utils.extract_spec(path) + solver = spec["completion_fns"][0] + for res in log_utils.extract_individual_results(path): + df_rows.append( + { + "solver": solver, + **res, + } + ) + df = pd.DataFrame(df_rows) + # Sort rows + # print(df.columns) + df.sort_values(by=["solver", "task"], inplace=True) + return df + + +def get_all_tasks(results_df: pd.DataFrame) -> list[str]: + # Find all types of tasks + all_tasks = [] + for i in results_df.columns: + if i.startswith("task_") and i.endswith("_CR_correct_rate"): + all_tasks.append(i) + + # Make ordering consistent + all_tasks.sort() + return all_tasks + + +def get_all_tasks_renamed(results_df: pd.DataFrame) -> list[str]: + all_tasks = get_all_tasks(results_df) + all_tasks_renamed = [i.split("task_")[1].split("_CR_correct_rate")[0] for i in all_tasks] + # replace hyphens with underscores + 
all_tasks_renamed = [i.replace("-", "_") for i in all_tasks_renamed] + return all_tasks_renamed + + +def get_unique_models(results_df: pd.DataFrame) -> list[str]: + models = results_df["model"].to_list() + # TODO: work out how to order a variable set of models + if set(models) == set(MODEL_NAMES.keys()): + unique_models = list(MODEL_NAMES.keys()) + else: + unique_models = sorted(list(set(models)), reverse=True) + return unique_models + + +def get_cleaned_model_name(model: str) -> str: + return model.replace("/", "_") + + +def corrects_to_accuracy_and_sem(corrects: pd.Series): + accuracy = corrects.mean() + sem = corrects.sem() + return accuracy, sem + + +def annotate_axes(ax, errors: Optional[pd.DataFrame]): + """Annotate each bar in the plot with its value""" + ABOVE_OFFSET = 0.01 + BELOW_OFFSET = 0.1 + if errors is not None: + # This gets it into a shape to match the order of the patch objects. + # I don't have a principled reason to transpose, this is just what works. + error_values = errors.to_numpy().T.flatten() + + for i, p in enumerate(ax.patches): + # patch objects aren't typed correctly + p_height = p.get_height() # type: ignore + p_x = p.get_x() # type: ignore + p_width = p.get_width() # type: ignore + # Calculate the label position + x = p_x + p_width / 2 + if errors is not None: + error = error_values[i] + else: + error = 0 + + if p_height > 0: + y = p_height + error + ABOVE_OFFSET + else: + y = p_height - error - BELOW_OFFSET + + # Annotate the bar with its value + # ax.annotate(f"{p_height:.2f}\n±{error:.2f}", (x, y), ha="center", va="bottom") + ax.annotate(f"{p_height:.2f}", (x, y), ha="center", va="bottom") + + +def corrects_to_performance_loss_and_error(CR_corrects: pd.Series, IR_corrects: pd.Series): + CR_correct_rate = CR_corrects.mean() + IR_correct_rate = IR_corrects.mean() + + performance_recovered = IR_correct_rate / CR_correct_rate + performance_loss = 1 - performance_recovered + # propagate error from CR_corrects and IR_corrects to 
performance_loss + CR_correct_rate_sem = CR_corrects.sem() + IR_correct_rate_sem = IR_corrects.sem() + assert isinstance(CR_correct_rate_sem, float) + assert isinstance(IR_correct_rate_sem, float) + # using the formula for error propagation for a ratio from + # https://en.wikipedia.org/wiki/Propagation_of_uncertainty#Example_formulae + # (assuming errors in CR and IR are independent). + # NOTE: the 1 in performance_loss is a constant, + # so doesn't affect the uncertainty bounds on the ratio. + CR_term = (CR_correct_rate_sem / CR_correct_rate) ** 2 + IR_term = (IR_correct_rate_sem / IR_correct_rate) ** 2 + performance_loss_error = abs(performance_recovered) * ((CR_term + IR_term) ** 0.5) + print(f"Performance loss: {performance_loss:.2f} ± {performance_loss_error:.2f}") + return performance_loss, performance_loss_error + + +def accuracy_by_task(metrics_df, results_df: pd.DataFrame, out_dir: Path): + all_tasks = get_all_tasks(results_df) + unique_models = get_unique_models(results_df) + all_tasks_renamed = get_all_tasks_renamed(results_df) + + # Plot results separately for each model + for model in unique_models: + plot_accuracy_by_task(model, metrics_df, all_tasks, all_tasks_renamed, out_dir) + + +def accuracy_by_model_dfs(metrics_df, results_df: pd.DataFrame): + unique_models = get_unique_models(results_df) + accuracies = {} + sems = {} + for model in unique_models: + pass + # for all tasks + model_mask = metrics_df.solver == model + model_CR_corrects = metrics_df[model_mask]["CR_correct"] + model_IR_corrects = metrics_df[model_mask]["IR_correct"] + model_NR_corrects = metrics_df[model_mask]["NR_correct"] + + model_CR_accuracy, model_CR_sem = corrects_to_accuracy_and_sem(model_CR_corrects) + model_IR_accuracy, model_IR_sem = corrects_to_accuracy_and_sem(model_IR_corrects) + model_NR_accuracy, model_NR_sem = corrects_to_accuracy_and_sem(model_NR_corrects) + + pretty_model_name = MODEL_NAMES[model] + sems[pretty_model_name] = { + "nr_name": model_NR_sem, + 
"cr_name": model_CR_sem, + "ir_name": model_IR_sem, + } + accuracies[pretty_model_name] = { + "nr_name": model_NR_accuracy, + "cr_name": model_CR_accuracy, + "ir_name": model_IR_accuracy, + } + + order = ["nr_name", "cr_name", "ir_name"] + plot_df = pd.DataFrame(accuracies) + plot_df = plot_df.reindex(order) + sems_df = pd.DataFrame(sems) + sems_df = sems_df.reindex(order) + return plot_df, sems_df + + +def accuracy_by_model(metrics_df, results_df: pd.DataFrame, out_dir: Path): + unique_models = get_unique_models(results_df) + plot_df, sems_df = accuracy_by_model_dfs(metrics_df, results_df) + + fig, ax = plt.subplots(figsize=(12, 6), constrained_layout=True) + colors = [MODEL_COLOR_MAP[model] for model in unique_models] + plot_df.index = list(VARIATION_NAMES.values()) + sems_df.index = list(VARIATION_NAMES.values()) + ax = plot_df.plot.bar( + rot=0, + yerr=sems_df, + capsize=4, + ax=ax, + width=0.8, + color=colors, + ) + annotate_axes(ax, sems_df) + ax.set_ylim(top=1.0) + ax.set_xlabel("Reasoning variations") + ax.set_ylabel("Accuracy") + ax.set_title("Accuracy for each variation (higher is better)") + + outpath = os.path.join(out_dir, "accuracy_by_model.png") + fig.savefig(outpath) + maybe_show(fig) + + +def accuracy_by_model_and_reasoning( + own_metrics_df: pd.DataFrame, + own_results_df: pd.DataFrame, + other_metrics_df: pd.DataFrame, + other_results_df: pd.DataFrame, + out_dir: Path, +): + own_plot_df, own_sems_df = accuracy_by_model_dfs(own_metrics_df, own_results_df) + other_plot_df, other_sems_df = accuracy_by_model_dfs(other_metrics_df, other_results_df) + # drop the no reasoning baseline + own_plot_df = own_plot_df.drop("nr_name", axis=0) + own_sems_df = own_sems_df.drop("nr_name", axis=0) + other_plot_df = other_plot_df.drop("nr_name", axis=0) + other_sems_df = other_sems_df.drop("nr_name", axis=0) + + own_plot_df = own_plot_df.T + own_sems_df = own_sems_df.T + other_plot_df = other_plot_df.T + other_sems_df = other_sems_df.T + models = own_plot_df.index 
# e.g., ["No reasoning (baseline)", "Correct reasoning", ...] + n_models = len(models) + bar_width = 0.35 # width of the bars + n_variations = len(own_plot_df.columns) + assert n_variations == len(other_plot_df.columns) + group_width = 0.8 # Total width for one group of bars + bar_width = group_width / (n_variations * 2) # Width of one bar + + # Create figure and axis + fig, ax = plt.subplots(figsize=(12, 8), constrained_layout=True) + + # Set position of bar on X axis + ind = np.arange(n_models) # the x locations for the groups + + colors = [VARIATION_COLOR_MAP[variation] for variation in own_plot_df.columns] + VARIATION_OFFSET = 0.03 + for i, variation in enumerate(own_plot_df.columns): + # Position of bars for this model + # bars for a given model are grouped together, and then within that group, the bars for each variation are grouped + r1 = ind + i * VARIATION_OFFSET + i * (n_variations * bar_width) + r2 = [x + bar_width for x in r1] + + ax.bar( + r1, + own_plot_df[variation], + width=bar_width, + yerr=own_sems_df[variation], + capsize=5, + label=f"{VARIATION_NAMES[variation]} ('assistant' message)", + color=colors[i], + # add outline to bars + edgecolor="black", + ) + ax.bar( + r2, + other_plot_df[variation], + width=bar_width, + yerr=other_sems_df[variation], + capsize=5, + label=f"{VARIATION_NAMES[variation]} ('user' message)", + hatch="//", + color=colors[i], + edgecolor="black", + ) + + for j, model in enumerate(models): + x_own = r1[j] + x_other = r2[j] + y1 = own_plot_df.loc[model, variation] + y2 = other_plot_df.loc[model, variation] + y1_err = own_sems_df.loc[model, variation] + y2_err = other_sems_df.loc[model, variation] + ax.text(x_own, y1 + y1_err, f"{y1:.2f}", ha="center", va="bottom") + ax.text(x_other, y2 + y2_err, f"{y2:.2f}", ha="center", va="bottom") + + # Add xticks on the middle of the group bars + xtick_positions = ind + bar_width * n_variations + (VARIATION_OFFSET - bar_width) / 2 + ax.set_xticks(xtick_positions) + 
ax.set_xticklabels(models) + + # Create legend & Show graphic + ax.set_xlabel("Model") + ax.set_ylabel("Accuracy") + ax.set_ylim(top=1.0) + ax.legend() + ax.set_title("Accuracy for each variation (higher is better)") + outpath = os.path.join(out_dir, "accuracy_by_category_and_reasoning.png") + fig.savefig(outpath) + maybe_show(fig) + + +def plot_accuracy_by_steps_all(metrics_df, results_df, out_dir): + """ + Create plots of accuracy of: + - num_steps - mistake_index + - mistake_index / num_steps + """ + get_all_tasks(results_df) + all_tasks_renamed = get_all_tasks_renamed(results_df) + all_models = get_unique_models(results_df) + # one plot per task, one subplot per model + for task in all_tasks_renamed: + fig, axs = plt.subplots( + 1, len(all_models), figsize=(15, 6), constrained_layout=True, squeeze=False + ) + axs = axs.flatten() + for ax, model in zip(axs, all_models): + task_model_df = metrics_df[(metrics_df.solver == model) & (metrics_df.task == task)] + plot_accuracy_by_steps(task_model_df, task, model, ax) + # only put legend on last plot + final_ax = axs[-1] + final_ax.legend(loc="upper center") + outpath = os.path.join(out_dir, f"results-split-by-steps_{task}.png") + fig.suptitle(f"Accuracy by steps for {TASK_NAMES[task]} (higher is better)") + fig.savefig(outpath) + maybe_show(fig) + + +def plot_accuracy_by_steps(df, task, model, ax): + df["steps_diff"] = df["num_ground_truth_steps"] - df["mistake_index"] + + # due to the way pandas works, we have to group, then filter, then regroup + grouped_df = df.groupby("steps_diff") + + MIN_SAMPLES = 10 + filtered_groups = grouped_df.filter(lambda x: len(x) >= MIN_SAMPLES) + + # Now, re-group the filtered DataFrame by 'steps_diff' again and calculate the mean + plot_df = filtered_groups.groupby("steps_diff")[ + ["NR_correct", "CR_correct", "IR_correct"] + ].mean() + colors = [VARIATION_COLOR_MAP[variation] for variation in VARIATION_NAMES.keys()] + + # change the names of the columns to be more readable before 
plotting + plot_df.columns = list(VARIATION_NAMES.values()) + # now plot the three accuracies against steps_diff + assert isinstance(plot_df, pd.DataFrame) + ax = plot_df.plot(color=colors, ax=ax, legend=False) + ax.set_xlabel("Steps beyond mistake") + ax.set_ylabel("Accuracy") + ax.set_ylim(0, 1.1) + # ax.set_title(f"{MODEL_NAMES[model]} | {TASK_NAMES[task]} (higher is better)") + ax.set_title(f"{MODEL_NAMES[model]}") + # plt.tight_layout() + return ax + + +def plot_accuracy_by_task(model, metrics_df, all_tasks, all_tasks_renamed, out_dir): + all_tasks_pretty = [TASK_NAMES[i] for i in all_tasks_renamed] + accuracies = {"nr_name": [], "cr_name": [], "ir_name": []} + all_sems = [] + # for all tasks + model_mask = metrics_df.solver == model + + # and split by task type + for task in all_tasks_renamed: + + task_mask = metrics_df.task == task + CR_corrects = metrics_df[model_mask & task_mask]["CR_correct"] + IR_corrects = metrics_df[model_mask & task_mask]["IR_correct"] + NR_corrects = metrics_df[model_mask & task_mask]["NR_correct"] + + CR_accuracy, CR_sem = corrects_to_accuracy_and_sem(CR_corrects) + IR_accuracy, IR_sem = corrects_to_accuracy_and_sem(IR_corrects) + NR_accuracy, NR_sem = corrects_to_accuracy_and_sem(NR_corrects) + + accuracies["nr_name"].append(NR_accuracy) + accuracies["cr_name"].append(CR_accuracy) + accuracies["ir_name"].append(IR_accuracy) + + sems = [NR_sem, CR_sem, IR_sem] + all_sems.append(sems) + + sems_df = pd.DataFrame( + all_sems, + index=all_tasks_pretty, + columns=["nr_name", "cr_name", "ir_name"], + ) + + plot_df = pd.DataFrame(accuracies, index=all_tasks_pretty) + + fig, ax = plt.subplots(figsize=(15, 6), constrained_layout=True) + colors = [VARIATION_COLOR_MAP[variation] for variation in plot_df.columns] + plot_df.columns = list(VARIATION_NAMES.values()) + ax = plot_df.plot.bar(rot=0, color=colors, yerr=sems_df, capsize=4, ax=ax, width=0.8) + annotate_axes(ax, sems_df) + + # Shrink current axis by 20% to make room for the legend + box 
= ax.get_position() + ax.set_position((box.x0, box.y0, box.width * 0.8, box.height)) + # Place the legend outside the plot + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + ax.set_ylim(top=1.1) + ax.set_xlabel("Task type") + ax.set_ylabel("Accuracy") + ax.set_title(f"{MODEL_NAMES[model]} (higher is better)") + outpath = os.path.join(out_dir, f"results-split-by-task_{get_cleaned_model_name(model)}.png") + fig.savefig(outpath) + maybe_show(fig) + + +def performance_loss_per_task(metrics_df: pd.DataFrame, results_df: pd.DataFrame, out_dir: Path): + # Plot performance lost for each model + unique_models = get_unique_models(results_df) + get_all_tasks(results_df) + all_tasks_renamed = get_all_tasks_renamed(results_df) + all_tasks_pretty = [TASK_NAMES[i] for i in all_tasks_renamed] + + all_metrics = {} + all_errors = {} + for model in unique_models: + metrics = [] + errors = [] + for task in all_tasks_renamed: + model_mask = metrics_df.solver == model + task_mask = metrics_df.task == task + CR_corrects = metrics_df[model_mask & task_mask]["CR_correct"] + IR_corrects = metrics_df[model_mask & task_mask]["IR_correct"] + + performance_loss, performance_loss_error = corrects_to_performance_loss_and_error( + CR_corrects, IR_corrects + ) + metrics.append(performance_loss) + errors.append(performance_loss_error) + + pretty_model_name = MODEL_NAMES[model] + all_metrics[pretty_model_name] = metrics + all_errors[pretty_model_name] = errors + + fig, ax = plt.subplots(figsize=(20, 6), constrained_layout=True) + plot_df = pd.DataFrame(all_metrics, index=all_tasks_pretty) + errs_df = pd.DataFrame(all_errors, index=all_tasks_pretty) + colors = [MODEL_COLOR_MAP[model] for model in unique_models] + ax = plot_df.plot.bar(rot=0.0, color=colors, ax=ax, width=0.8, yerr=errs_df, capsize=4) + annotate_axes(ax, errs_df) + # Shrink current axis by 20% to make room for the legend + box = ax.get_position() + ax.set_position((box.x0, box.y0, box.width * 0.8, box.height)) + 
ax.set_ylim(bottom=-1, top=1.1) + ax.legend() + ax.axhline(0, 0, 1, color="black", linestyle="-") + ax.set_title("Performance loss per task (lower is better)") + ax.set_xlabel("Task type") + ax.set_ylabel("Performance loss") + + outpath = os.path.join(out_dir, "results_split_by_model.png") + fig.savefig(outpath) + maybe_show(fig) + + +def performance_loss_per_model(metrics_df: pd.DataFrame, results_df: pd.DataFrame, out_dir: Path): + unique_models = get_unique_models(results_df) + + metrics = {} + errors = {} + for model in unique_models: + model_mask = metrics_df.solver == model + + CR_corrects = metrics_df[model_mask]["CR_correct"] + IR_corrects = metrics_df[model_mask]["IR_correct"] + + performance_loss, performance_loss_error = corrects_to_performance_loss_and_error( + CR_corrects, IR_corrects + ) + + pretty_model_name = MODEL_NAMES[model] + metrics[pretty_model_name] = performance_loss + errors[pretty_model_name] = performance_loss_error + + fig, ax = plt.subplots(figsize=(10, 6), constrained_layout=True) + plot_df = pd.DataFrame(metrics, index=[0]) + errs_df = pd.DataFrame(errors, index=[0]) + colors = [MODEL_COLOR_MAP[model] for model in unique_models] + ax = plot_df.plot.bar(rot=0, color=colors, ax=ax, width=0.8, yerr=errs_df, capsize=4) + annotate_axes(ax, errs_df) + # Shrink current axis by 20% to make room for the legend + box = ax.get_position() + ax.set_position((box.x0, box.y0, box.width * 0.8, box.height)) + # Place the legend outside the plot + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + ax.set_xticklabels([]) + ax.set_xticks([]) + ax.set_ylabel("Performance loss") + ax.set_ylim(top=1.1) + ax.set_title("Average performance loss per model (lower is better)") + outpath = os.path.join(out_dir, "headline_results.png") + fig.savefig(outpath) + maybe_show(fig) + + +def main(): + parser = argparse.ArgumentParser() + # DEBUG: hacking together own_reasoning and other_reasoning plots + parser.add_argument( + "--log_dir", + "-d", + type=str, + 
required=True, + help="Path to log dir with primary results (if supplementary_dir is provided, this should be 'own' reasoning)", + ) + parser.add_argument( + "--supplementary_dir", + "-s", + type=str, + help="Optional supplementary log dir with 'other' reasoning results", + ) + parser.add_argument("--out_dir", "-o", type=str, required=True) + args = parser.parse_args() + log_dir = Path(args.log_dir) + out_dir = Path(args.out_dir) + out_dir.mkdir(exist_ok=True, parents=True) + + metrics_df = extract_metrics(log_dir) + results_df = extract_results(log_dir) + if args.supplementary_dir: + other_log_dir = Path(args.supplementary_dir) + other_metrics_df = extract_metrics(other_log_dir) + other_results_df = extract_results(other_log_dir) + accuracy_by_model_and_reasoning( + metrics_df, results_df, other_metrics_df, other_results_df, out_dir + ) + accuracy_by_task(metrics_df, results_df, out_dir) + accuracy_by_model(metrics_df, results_df, out_dir) + performance_loss_per_task(metrics_df, results_df, out_dir) + performance_loss_per_model(metrics_df, results_df, out_dir) + plot_accuracy_by_steps_all(metrics_df, results_df, out_dir) + + +if __name__ == "__main__": + DISPLAY = False + main() diff --git a/evals/elsuite/error_recovery/scripts/run_experiments.sh b/evals/elsuite/error_recovery/scripts/run_experiments.sh new file mode 100755 index 0000000000..36f51faad4 --- /dev/null +++ b/evals/elsuite/error_recovery/scripts/run_experiments.sh @@ -0,0 +1,44 @@ +#!/bin/bash +logdir=./logs +outdir=./outputs + +timestamp=$(date +%Y%m%d_%H%M%S) +logpathbase=$logdir/$timestamp +outpathbase=$outdir/$timestamp +SPLIT=main + +mkdir -p ${logpathbase} + +export EVALS_THREADS=250 +echo Running full experiments and logging to $logpathbase + +declare -a SOLVERS=( + error_recovery/gpt-3.5-turbo-0613 + error_recovery/gpt-4-0613 + generation/hhh/gpt-4-base +) + +# OWN REASONING VARIANT +for solver in "${SOLVERS[@]}" +do + log_name=${SPLIT}_${solver//\//-}_own-reasoning + + oaieval $solver 
error-recovery.$SPLIT \ + --extra_eval_params final_answer_prompt_role=system \ + --record_path "$logpathbase/$log_name.log" +done + +# OTHER REASONING VARIANT +for solver in "${SOLVERS[@]}" +do + log_name=${SPLIT}_${solver//\//-}_other-reasoning + + oaieval $solver error-recovery.$SPLIT.other-reasoning \ + --extra_eval_params final_answer_prompt_role=system \ + --record_path "$logpathbase/$log_name.log" +done + +echo Producing plots, outputs to $outpathbase + +mkdir -p ${outpathbase} +python make_plots.py --log_dir ${logpathbase} --out_dir $outpathbase diff --git a/evals/elsuite/function_deduction/README.md b/evals/elsuite/function_deduction/README.md new file mode 100644 index 0000000000..924b4e47fb --- /dev/null +++ b/evals/elsuite/function_deduction/README.md @@ -0,0 +1,91 @@ +# Function Deduction + +This eval measures how well a model can refine a hypothesis according to new evidence and how well it chooses to gather new information. + +In Function Deduction: + +- There is a secret mathematical function that maps an integer to another integer. +- The evaluated model interacts with the function by picking inputs to run through the function and observing black-box outputs. +- The model’s goal is to correctly predict outputs for a specified set of inputs, which is only possible by working out the underlying logic of the function. + +![fd](https://github.com/openai/policy-research-evals/assets/129281094/6c41be74-8237-4bb3-b0fc-13454c20389c) + +## Usage + +Run with: + +``` +oaieval <solver> function_deduction +``` + +We suggest using `function_deduction/cot/gpt-4-32k` or `function_deduction/cot/gpt-3.5-turbo-16k` as default choices for `<solver>`. + +See `evals/registry/solvers/function_deduction.yaml` for a full list of recommended Solvers. + +For more examples of running this eval, see `scripts/run_experiments.sh`. + +## Dataset + +The dataset consists of 100 int -> int functions of varying difficulty. 
Outputs of the functions for inputs 0 through 100 (the allowable range) are pre-cached and in the dataset. During the eval run, the eval simply looks up the output for any input the solver is asking about. + +## Evaluation Process + +Samples from the dataset are evaluated one-by-one. + +A model has up to 20 rounds to guess what the function would return for the three test input integers. + +In each round the model can either: + +- Ask what the function returns for a given integer (other than the test inputs) *or* +- Make a guess for the values that would be returned by the function for the test inputs. + +If the model makes a guess and the guess is correct for all 3 test inputs, the model has solved this sample; otherwise, if it guesses wrong, the model is told whether it guessed any subset of the test inputs correctly. + +After 20 rounds without a correct guess, the evaluation for that sample ends and the next sample starts. Rounds where the model answered in an incorrect format are still counted towards the 20 round total. + +## Task State + +| Key | Value | +| --- | --- | +| `task_description` | I have an important task for you. Instructions below.

I have a deterministic function `def calc(x: int) -> int:` that accepts an integer in range [0, 100] as an argument and returns an integer.

I want you to correctly determine the result of this function for each of the following inputs: {inputs}

At each turn you must select one of the following 2 options:
Option (A): Gather useful information. Ask me what would be returned for a given input to inform your hypothesis of what the function is. \
This must be an integer in range [0, 100], other than {inputs}.
Option (B): State your guess for the values returned by the function for inputs {inputs}. Your guess must include all three numbers, in this order. \
If your guess is correct, you win. If not, the task will continue.

You have {n_rounds} turns to determine the correct answer, else you lose.

Your goal is to determine the correct answer in as few turns as possible, so choose your actions carefully!
Your final score will be:
* If you make the correct guess: {n_rounds} - number_of_rounds_played
* If you fail to make the correct guess: -{n_rounds}
You should try to score as high as possible.

You should respond either with a single integer in range 0-100 if you decide to take option A or with three integers if you make the final guess (i.e. use option B).
Don't say anything more except for either an integer or 3 integers. | +| `current_state` | CurrentState object that tracks various data from the current dialog. | + +## Metrics + +The key metrics of this eval are: + +| Metric | Interpretation | +| --- | --- | +| `adjusted_avg_score` | Combination of the two metrics below: the number of rounds for solved samples, or 40 for unsolved samples, averaged over all samples (lower is better) | +| `solved_ratio` | The percentage of solved samples (higher is better) | +| `avg_success_rounds` | The average number of rounds for solved samples (lower is better) | + +## Variants + +| Variant | Notes | +| --- | --- | +| Default: `function_deduction.easy` | Default setting as described above. 1 trial per sample | +| `function_deduction.easy.long` | 10 trials per sample | +| `function_deduction.easy.dev5` | Dev set with only 5 samples | +| `function_deduction.hard` | A hard variant where the model is only told ‘this guess is incorrect’ if it is wrong, instead of being told which inputs it got right/wrong. | +| `function_deduction.hard.dev5` | Dev set with only 5 samples | + +## Token Usage Estimates + +Below is a rough estimate of the total number of tokens consumed by the default variant: + +| Solver | Tokens | +| --- | --- | +| function_deduction/gpt-4-base | 3 840 000 | +| gpt-4-32k | 880 000 | +| gpt-3.5-turbo-16k | 1 560 000 | +| function_deduction/cot/gpt-4-32k | 12 400 000 | +| function_deduction/cot/gpt-3.5-turbo-16k | 13 230 000 | + +## Version History + +- v0: Initial version released + +## Contribution statement + +Eval design, implementation, and results evaluation were primarily conducted by Jan Betley with contributions from Andrei Alexandru. Report by James Aung. Work done under the guidance of (alphabetically by last name) Steven Adler and Chan Jun Shern, who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation. 
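As a rough illustration of how the three metrics in the Metrics table relate to each other, the sketch below computes them from per-sample outcomes. It is not the eval's actual implementation: the `summarize` helper and its `(solved, rounds_played)` input shape are hypothetical, and it assumes `adjusted_avg_score` is averaged over all samples with each unsolved sample counted as 40 rounds.

```python
from statistics import mean

# Rounds charged to an unsolved sample, as stated in the Metrics table.
FAILURE_PENALTY = 40


def summarize(samples: list[tuple[bool, int]]) -> dict[str, float]:
    """Compute the three headline metrics.

    `samples` holds hypothetical (solved, rounds_played) pairs, one per sample.
    """
    # Rounds used by the samples that were solved.
    success_rounds = [rounds for solved, rounds in samples if solved]
    # Solved samples keep their round count; unsolved ones are charged the penalty.
    adjusted = [rounds if solved else FAILURE_PENALTY for solved, rounds in samples]
    return {
        "adjusted_avg_score": mean(adjusted),
        "solved_ratio": len(success_rounds) / len(samples),
        # Undefined when nothing was solved, so report NaN in that case.
        "avg_success_rounds": mean(success_rounds) if success_rounds else float("nan"),
    }


if __name__ == "__main__":
    # Two solved samples (5 and 9 rounds) and one unsolved sample.
    print(summarize([(True, 5), (True, 9), (False, 20)]))
```

Under this reading, a solver can only lower `adjusted_avg_score` by solving more samples or by solving them in fewer rounds, which is why the table describes it as a combination of the other two metrics.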
diff --git a/evals/elsuite/function_deduction/baselines.py b/evals/elsuite/function_deduction/baselines.py new file mode 100644 index 0000000000..3a81624e03 --- /dev/null +++ b/evals/elsuite/function_deduction/baselines.py @@ -0,0 +1,133 @@ +import logging +import math +from collections import Counter +from pathlib import Path + +import numpy as np +from scipy.stats import entropy + +from evals.data import get_jsonl +from evals.elsuite.function_deduction.eval import CurrentState, Sample +from evals.registry import Registry +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import TaskState + + +class AverageBaseline(Solver): + """ + For given test inputs (x, y, z): + * Ask about values of (x-1, x+1, y-1, y+1, z-1, z+1) + * Make three guesses: round/floor/ceil of average values for neighboring numbers + If didn't succeed in 9 rounds (6x ask 3x guess) - surrender. + + Note: This algorithm fails on the edge cases where, for any of the inputs i: + - i-1 or i+1 is out of range + - i-1 or i+1 are part of the test inputs + In this scenario, the algorithm will fail at the _get_guess stage and skip the guessing. 
+ """ + + def __init__(self, registry=None): + pass + + def _solve(self, task_state: TaskState): + cs: CurrentState = task_state.current_state + + assert len(cs.test_inputs) == 3, "AverageBaseline assumes 3 test inputs" + + if cs.round_ix < 6: + response = self._get_ask(cs.test_inputs, cs.round_ix) + elif 6 <= cs.round_ix < 9: + response = self._get_guess(cs.test_inputs, cs.known_values, cs.round_ix - 6) + else: + response = "I've run out of ideas sorry :(" + return SolverResult(response) + + def _get_guess(self, test_inputs, known_values: dict[int, int], guess_round_ix) -> str: + known_values = { + x: y for x, y in known_values.items() if x - 1 in test_inputs or x + 1 in test_inputs + } + + pairs = [[], [], []] + for i, test_input in enumerate(test_inputs): + try: + lower = known_values[test_input - 1] + higher = known_values[test_input + 1] + except KeyError: + return "Unfortunately I don't have enough data to make a guess, will pass." + pairs[i] = [lower, higher] + + funcs = [round, math.floor, math.ceil] + func = funcs[guess_round_ix] + vals = [func((pair[0] + pair[1]) / 2) for pair in pairs] + return " ".join([str(x) for x in vals]) + + def _get_ask(self, test_inputs, round_ix) -> str: + queries = [] + for x in test_inputs: + queries.append(x - 1) + queries.append(x + 1) + + ask = queries[round_ix] + if ask in test_inputs or ask < 0 or ask > 100: + logging.warning( + f"Invalid query on inputs {test_inputs}: {ask}. AverageBaseline algorithm will fail." + ) + return str(ask) + + +class FullKnowledge(Solver): + """Assuming solver knows all the samples, how well would it perform? + + Two modes - "random", where it selects random integer when asking, + and "best" where it selects the best integer. + + The "best" mode should be close to unbeatable (except for lucky guesses). 
+ """ + + def __init__(self, mode: str, samples_jsonl: str, registry: Registry): + assert mode in ("random", "best"), "mode must be either random or best" + self.mode = mode + self._all_samples = self._get_samples(samples_jsonl, registry._registry_paths[0]) + self._rng = np.random.default_rng() + + def _solve(self, task_state: TaskState): + cs: CurrentState = task_state.current_state + + matching_samples = self._get_matching_samples(cs.known_values) + if len(matching_samples) > 1: + if self.mode == "random": + response = self._get_ask_random(cs.known_values) + else: + response = self._get_ask_best(matching_samples) + else: + sample_values = matching_samples[0].values + result = [sample_values[test_input] for test_input in cs.test_inputs] + response = " ".join([str(x) for x in result]) + return SolverResult(str(response)) + + def _get_matching_samples(self, known_values): + def matches(sample: Sample) -> bool: + for key, val in known_values.items(): + if sample.values[key] != val: + return False + return True + + return [sample for sample in self._all_samples if matches(sample)] + + def _get_ask_best(self, samples): + def get_entropy(x: int) -> float: + values = [sample.values[x] for sample in samples] + counter = Counter(values) + return entropy([val for val in counter.values()]) + + return max(range(0, 101), key=get_entropy) + + def _get_ask_random(self, known_values): + while True: + x = self._rng.integers(0, 101)  # upper bound is exclusive; inputs span [0, 100] + if x not in known_values: + return x + + def _get_samples(self, samples_jsonl: str, registry_path: Path): + path = registry_path / "data" / samples_jsonl + return [Sample(**x) for x in get_jsonl(path.as_posix())] diff --git a/evals/elsuite/function_deduction/eval.py b/evals/elsuite/function_deduction/eval.py new file mode 100644 index 0000000000..6542852153 --- /dev/null +++ b/evals/elsuite/function_deduction/eval.py @@ -0,0 +1,302 @@ +import logging +import random +import re +from dataclasses import dataclass, field +from typing import List, Literal, 
Optional, Tuple, Union + +import numpy as np +import scipy + +import evals +from evals.api import CompletionFn +from evals.elsuite.function_deduction import prompts +from evals.eval import SolverEval +from evals.solvers.solver import Solver +from evals.task_state import Message, TaskState + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class Sample: + sample_ix: int + code: str + complexity: int + range: List[int] + values: List[int] + + +@dataclass +class CurrentState: + """This class tracks all the information from the dialog with the model. + + Some things are tracked to make writing solvers easier. + Other are tracked for metrics. + """ + + n_rounds: int + mode: str + test_inputs: tuple[int, int, int] + success: bool = False + known_values: dict[int, int] = field(default_factory=dict) + negative_known_values: dict[int, int] = field(default_factory=dict) + ask_rounds: int = 0 + guess_rounds: int = 0 + incorrect_format_rounds: int = 0 + parsed_responses: list[tuple[int]] = field(default_factory=list) + + @property + def round_ix(self): + return self.ask_rounds + self.guess_rounds + self.incorrect_format_rounds + + def ask_update(self, input_: int, value: Optional[int]) -> None: + self.ask_rounds += 1 + self.parsed_responses.append((input_,)) + if value is not None: + self.known_values[input_] = value + + def guess_update( + self, guessed_ints: tuple[int, int, int], expected_ints: tuple[int, int, int] + ) -> None: + self.guess_rounds += 1 + self.parsed_responses.append(guessed_ints) + if guessed_ints == expected_ints: + self.success = True + + if self.mode == "easy": + for test, guess, correct in zip(self.test_inputs, guessed_ints, expected_ints): + if guess == correct: + self.known_values[test] = guess + else: + self.negative_known_values[test] = guess + + +class FunctionDeductionEval(SolverEval): + def __init__( + self, + completion_fns: list[CompletionFn], + mode: Literal["easy", "hard"], + n_rounds: int, + n_samples: Optional[int] = None, 
+ n_repeat: int = 3, + failed_sample_rounds: Optional[int] = None, + seed: Optional[int] = None, + samples_jsonl: str = "function_deduction/data.jsonl", + *args, + **kwargs, + ): + super().__init__(completion_fns, seed=seed, samples_jsonl=samples_jsonl, *args, **kwargs) + + self.mode = mode + self.n_rounds = n_rounds + self.n_samples = n_samples + self.n_repeat = n_repeat + + # This is used for the main metric - "how many rounds for a sample that was not solved?" + self.failed_sample_rounds = ( + failed_sample_rounds if failed_sample_rounds is not None else n_rounds * 2 + ) + + def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random): + test_inputs = rng.sample(range(101), 3) + values = sample.values + expected = tuple(sample.values[test_input] for test_input in test_inputs) + + cs = CurrentState(self.n_rounds, self.mode, test_inputs) + task_state = TaskState( + prompts.task_description.format(inputs=test_inputs, n_rounds=self.n_rounds), + current_state=cs, + ) + + for round_ix in range(self.n_rounds): + raw_response = solver(task_state).output + try: + ints = self._parse_raw_response(raw_response) + except ValueError: + cs.incorrect_format_rounds += 1 + answer = prompts.incorrect_format + else: + if len(ints) == 1: + ask = ints[0] + result = values[ask] if ask not in test_inputs else None + cs.ask_update(ask, result) + if result is None: + answer = prompts.test_input_not_allowed.format(inputs=test_inputs) + else: + answer = prompts.new_value.format(in_=ask, out=result) + else: + cs.guess_update(ints, expected) + if cs.success: + break + else: + answer = self._bad_guess_answer(test_inputs, ints, expected) + + task_state.messages += [ + Message("assistant", raw_response), + Message("system", answer), + ] + + evals.record.record_metrics( + sample_ix=sample.sample_ix, + success=cs.success, + num_rounds=cs.round_ix if cs.success else None, + ask_rounds=cs.ask_rounds, + guess_rounds=cs.guess_rounds, + incorrect_format_rounds=cs.incorrect_format_rounds, 
+ repeated_rounds=len(cs.parsed_responses) - len(set(cs.parsed_responses)), + code="lambda x: " + sample.code, + complexity=sample.complexity, + ) + + def run(self, recorder: evals.record.Recorder): + samples = self.get_samples() + + # Add copies according to self.n_repeat + # NOTE: we have copies next to each other -> more convenient when reading in logviz + copied_samples = [sample for sample in samples for _ in range(self.n_repeat)] + logger.info( + f"{len(samples)} unique samples, {self.n_repeat} attempts for each sample, {len(copied_samples)} total samples" + ) + self.eval_all_samples(recorder, copied_samples) + metrics = recorder.get_metrics() + + adjusted_rounds = [x["num_rounds"] or self.failed_sample_rounds for x in metrics] + main_metric = sum(adjusted_rounds) / len(metrics) + result = { + "adjusted_avg_score": main_metric, + "sem_adjusted_avg_score": self._calculate_sem(adjusted_rounds), + } + + result.update(self._get_success_metrics(metrics)) + result.update(self._get_sample_std(metrics)) + for name in ("ask_rounds", "guess_rounds", "incorrect_format_rounds"): + result[f"avg_{name}"] = sum(x[name] for x in metrics) / len(metrics) + result[f"sem_avg_{name}"] = self._calculate_sem([x[name] for x in metrics]) + result.update(self._get_complexity_tests(metrics)) + result.update(self._get_per_complexity_metrics(metrics)) + + return result + + def _calculate_sem(self, values: list) -> float: + return np.std(values) / np.sqrt(len(values)) + + def _get_success_metrics(self, metrics): + success = [x for x in metrics if x["success"]] + return { + "solved_ratio": round(len(success) / len(metrics), 2), + "sem_solved_ratio": self._calculate_sem([x["success"] for x in metrics]), + "solved": len(success), + "samples": len(metrics), + "avg_success_rounds": round(sum(x["num_rounds"] for x in success) / len(success), 2) + if success + else None, + "sem_avg_success_rounds": self._calculate_sem([x["num_rounds"] for x in success]) + if success + else None, + } + + def 
_get_sample_std(self, metrics): + adjusted = [] + no_failed = [] + solved_ratio_if_any_solved = [] + sample_ixs = set(metric["sample_ix"] for metric in metrics) + for sample_ix in sample_ixs: + sample_metrics = [metric for metric in metrics if metric["sample_ix"] == sample_ix] + sample_adjusted = [ + metric["num_rounds"] or self.failed_sample_rounds for metric in sample_metrics + ] + sample_no_failed = [ + metric["num_rounds"] for metric in sample_metrics if metric["success"] + ] + solved_ratio = sum(1 for metric in sample_metrics if metric["success"]) / len( + sample_metrics + ) + + if len(sample_adjusted) > 1: + adjusted.append(np.std(sample_adjusted)) + if len(sample_no_failed) > 1: + no_failed.append(np.std(sample_no_failed)) + if solved_ratio: + solved_ratio_if_any_solved.append(solved_ratio) + + return { + "avg_sample_rounds_std_adjusted": sum(adjusted) / len(adjusted) if adjusted else None, + "avg_sample_rounds_std_no_failed": sum(no_failed) / len(no_failed) + if no_failed + else None, + # This is just solved_ratio but excluding samples that had no successful attempt. + # So 1 is full stability (i.e. 
if a sample was solved once, it is always solved), + # and (1/self.n_repeat) is "no sample was solved more than once" + "solved_ratio_if_any_solved": sum(solved_ratio_if_any_solved) + / len(solved_ratio_if_any_solved) + if solved_ratio_if_any_solved + else None, + } + + def _get_complexity_tests(self, metrics): + solved = [x["complexity"] for x in metrics if x["success"]] + not_solved = [x["complexity"] for x in metrics if not x["success"]] + result = { + "solved_avg_complexity": sum(solved) / len(solved) if solved else None, + "not_solved_avg_complexity": sum(not_solved) / len(not_solved) if not_solved else None, + } + + # This tests if solved have lower complexity than non-solved + if solved and not_solved: + _, p_value = scipy.stats.mannwhitneyu(solved, not_solved, alternative="less") + else: + p_value = None + result["solved_or_not_mann_whitney_u_p_value"] = p_value + + # TODO: add more complexity-related metrics, such as correlation or linear regression coefficient. + # Leaving this for the future because we might want to change how the complexity is calculated, + # or generally improve the concept somehow. + + return result + + def _get_per_complexity_metrics(self, all_metrics): + complexity_values = sorted(set(x["complexity"] for x in all_metrics)) + result = {} + for complexity in complexity_values: + metrics = [x for x in all_metrics if x["complexity"] == complexity] + result[f"complexity_{complexity}"] = self._get_success_metrics(metrics) + return result + + def _parse_raw_response(self, response: str) -> Union[Tuple[int], Tuple[int, int, int]]: + # Remove all non-numbers first. This way we accept also e.g. "1, 2, 3", "[1, 2, 3]", '"1", "2", "3"' etc. 
+ response = re.sub(r"[^0-9\s-]", "", response) + + vals = tuple(int(x) for x in response.split()) + if len(vals) not in (1, 3): + raise ValueError("Expected 1 or 3 integers") + if len(vals) == 1 and not 0 <= vals[0] <= 100: + raise ValueError("Single int should be between 0 and 100") + return vals + + def _bad_guess_answer(self, test_inputs, guessed, expected) -> str: + correct = [test_inputs[i] for i in range(0, 3) if guessed[i] == expected[i]] + incorrect = [x for x in test_inputs if x not in correct] + assert incorrect, "This is not a bad answer" + + if self.mode == "hard": + return "This is not the correct answer. At least one of the values is wrong." + elif self.mode == "easy": + if len(correct) == 0: + return "All numbers are wrong." + elif len(correct) == 1: + return f"Your guess is correct for {correct[0]} and incorrect for {incorrect[0]} and {incorrect[1]}" + elif len(correct) == 2: + return f"Your guess is correct for {correct[0]} and {correct[1]} and incorrect for {incorrect[0]}" + else: + raise ValueError(f"Unknown mode {self.mode}") + + def get_samples(self) -> List[Sample]: + samples = super().get_samples() + + if self.n_samples is not None: + assert ( + len(samples) >= self.n_samples + ), f"Can't get {self.n_samples} samples from a dataset with {len(samples)} samples" + np.random.default_rng(seed=self.seed).shuffle(samples) + samples = samples[: self.n_samples] + return [Sample(**sample_dict) for sample_dict in samples] diff --git a/evals/elsuite/function_deduction/prompts.py b/evals/elsuite/function_deduction/prompts.py new file mode 100644 index 0000000000..948e14079f --- /dev/null +++ b/evals/elsuite/function_deduction/prompts.py @@ -0,0 +1,43 @@ +task_description = """\ +I have an important task for you. Instructions below. + +I have a deterministic function `def calc(x: int) -> int:` that accepts an integer in range [0, 100] as an argument and returns an integer. 
+ +I want you to correctly determine the result of this function for each of the following inputs: {inputs} + +At each turn you must select one of the following 2 options: +Option (A): Gather useful information. Ask me what would be returned for a given input to inform your hypothesis of what the function is. \ +This must be an integer in range [0, 100], other than {inputs}. +Option (B): State your guess for the values returned by the function for inputs {inputs}. Your guess must include all three numbers, in this order. \ +If your guess is correct, you win. If not, the task will continue. + +You have {n_rounds} turns to determine the correct answer, else you lose. + +Your goal is to determine the correct answer in as few turns as possible, so choose your actions carefully! +Your final score will be: +* If you make the correct guess: {n_rounds} - number_of_rounds_played +* If you fail to make the correct guess: -{n_rounds} +You should try to score as high as possible. + +You should respond either with a single integer in range 0-100 if you decide to take option A or with three integers if you make the final guess (i.e. use option B). +Don't say anything more except for either an integer or 3 integers. +""" +# TODO: Include in the task description whether we're in the easy or hard mode? +# (i.e. whether we'll be told about correctly guessed numbers (if any) or not). + +incorrect_format = "You responded in an incorrect format. Respond with either a single integer in range 0-100 or 3 integers. Just one number or three numbers, nothing more." +new_value = "calc({in_}) = {out}" +test_input_not_allowed = "Numbers {inputs} are not allowed" + +cot_template_first_round = """This is the first round. Think out loud about a general strategy of solving tasks like this.""" +cot_template_later_rounds = """\ +This is round {round_ix} out of {num_rounds}. 
+ +Summary of the information you have already gathered: +{summary} + +Think out loud about the following questions: +* Do you have any hypothesis on what this function might be doing? +* If yes, should you try to test it (how?), or just use it to calculate the answer? +* If not, what additional information should you gather to be able to formulate a hypothesis? +""" diff --git a/evals/elsuite/function_deduction/scripts/dataset/create_dataset.py b/evals/elsuite/function_deduction/scripts/dataset/create_dataset.py new file mode 100644 index 0000000000..931e1cc27a --- /dev/null +++ b/evals/elsuite/function_deduction/scripts/dataset/create_dataset.py @@ -0,0 +1,62 @@ +import argparse +import dis +import json +import math + +DEFAULT_RANGE = [0, 100] # inclusive + + +def get_func_from_code(code): + return lambda x: eval(code, {"math": math, "x": x}) + + +def get_complexity(code: str) -> int: + # NOTE: this is quite ugly, but should be good enough for dataset-creating code + code = "global func_name\ndef func_name(x): return " + code + exec(code) + return len(list(dis.get_instructions(func_name))) + + +def create_dataset(out_file, in_file): + samples = [] + + for line in in_file: + line = line.strip() + + if not line or line.startswith("#"): + continue + + func = get_func_from_code(line) + values = list(int(func(x)) for x in range(DEFAULT_RANGE[0], DEFAULT_RANGE[1] + 1)) + samples.append( + { + "code": line, + "complexity": get_complexity(line), + "range": DEFAULT_RANGE, + "values": values, + } + ) + + # Ensure we don't have duplicates - they might be different functions, but if they return the same + # value for every input in the DEFAULT_RANGE then they are in fact the same sample. 
+ for sample_ix, sample in enumerate(samples): + for other_sample in samples[sample_ix + 1 :]: + if sample["values"] == other_sample["values"]: + raise ValueError( + f"Samples {sample['code']} and {other_sample['code']} are indistinguishable" + ) + + samples.sort(key=lambda x: x["complexity"]) + + for i, sample in enumerate(samples): + sample = dict(sample_ix=i, **sample) + json.dump(sample, out_file) + out_file.write("\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--out", type=argparse.FileType("w"), required=True) + parser.add_argument("--in", dest="in_", type=argparse.FileType("r"), default="raw_code.txt") + args = parser.parse_args() + create_dataset(out_file=args.out, in_file=args.in_) diff --git a/evals/elsuite/function_deduction/scripts/dataset/raw_code.txt b/evals/elsuite/function_deduction/scripts/dataset/raw_code.txt new file mode 100644 index 0000000000..ff03a0c76e --- /dev/null +++ b/evals/elsuite/function_deduction/scripts/dataset/raw_code.txt @@ -0,0 +1,141 @@ +# Lines starting with '#' or empty are ignored. +# Every other line is code for a single sample. +# This file is parsed by create_dataset.py script +# (-> is not accessed when the eval is running). 
+ +# Single operation +x +x * 2 +x * 27 +-x +x * -2 +x * -19 +math.floor(x * 1.5) +math.floor(x * 8.5) +math.floor(x / 2) +math.floor(x / 10) +math.ceil(x / 2) +round(x / 10) +math.ceil(x / 10) +x + 1 +x + 17 +x - 1 +x - 29 +7 - x +x ** 2 +x ** 3 + +# Small set of values +7 +7 if x % 2 else 17 +x % 3 +x % 7 +x % 10 +int(x % 3 == 1) +int(x % 3 == 2) +int(x % 3 == 1) * 7 +int(x % 3 == 2) * 18 +int(x < 32) +int(x % 8 < 4) + +# Threshold +min(x, 30) +max(x, 30) +min(x * 2, 70) +max(x * 2, 70) +x * 2 if x < 50 else x +x + 7 if x < 50 else x - 7 +x + 50 if x < 50 else 100 - x +x * 2 if x > 40 else x * 3 +3 if 30 < x < 70 else 4 +min(1000000, 2 ** x) + +# Multiple operations +math.floor(x + math.sqrt(x)) +math.floor(math.sqrt(x)) +math.floor(math.sqrt(x)) - 1 +math.floor(math.sqrt(x)) * 2 +math.floor(math.sqrt(x) * 2) +math.floor(round(x ** (1/3), 8)) +x / 2 if not x % 2 else x * 3 +x / 2 if not x % 2 else x * 3 + 1 +x ** 2 if x % 2 else x ** 3 +x / 3 if not x % 3 else x +x / 3 if not x % 3 else x * 2 +(x + 1) / 3 if x % 3 == 2 else x +x ** 2 - 10 +x ** 3 - x ** 2 +x ** 2 * 2 +x * (x - 1) +x * (x - 1) * (x - 2) +x * (x + 1) / 2 +5 - (x % 5) +10 - (x % 10) +16 - (x % 16) +x - x % 6 +x - x % 15 +x - x % 10 +x + x % 10 +x + x % 4 +x + x // 10 +x + x // 8 +x // 10 + x % 2 +(x + 5) * 3 +(x + 2) * 7 +(2 * x) ** 2 + + +# Math, sin, cos etc +round(math.sin(x)) +round(math.sin(x * 0.5 * math.pi)) +round(math.sin(x * 0.25 * math.pi) * 10) +round(math.sin(x * 0.1 * math.pi) * 10) +round(math.cos(x)) +round(math.cos(x * 0.5 * math.pi)) +round(math.cos(x * 0.25 * math.pi) * 10) +round(math.cos(x * 0.1 * math.pi) * 10) + +# Is prime number? +int(x > 1 and all(x % i for i in range(2, x))) +x if x > 1 and all(x % i for i in range(2, x)) else x + 1 + +# Is perfect square? 
+int(int(x**0.5)**2 == x) + +# Divisors - number / sum +sum(1 for i in range(1, x + 1) if not x % i) +sum(i for i in range(1, x + 1) if not x % i) + +# Reverse digits +int(str(x)[::-1]) +abs(x - int(str(x)[::-1])) +x + int(str(x)[::-1]) + +# Sum of digits +sum(int(d) for d in str(x)) +x + sum(int(d) for d in str(x)) +int(sum(int(d) for d in str(x)) % 10) + +# Count odd/even digits +sum(1 for d in str(x) if int(d) % 2) +sum(1 for d in str(x) if not int(d) % 2) + +# Multiple digits +0 if x < 10 else (x % 10) * (x // 10) + +# Higher vs lower digit +0 if x < 10 else max(int(d) for d in str(x)) - min(int(d) for d in str(x)) + +# Other +bin(x).count("1") +x | 1 +int(str(x) == str(x)[::-1]) +x * int(str(x)[-1]) + +# More ideas: convert to binary +# int(bin(x)[2:]) +# int(bin(~x)[3:]) +# int(bin(x * 2)[2:]) + +# More ideas: highest divisor lower than x? +# 0 if x == 0 else max(1 for i in range(1, x) if not x % i) diff --git a/evals/elsuite/function_deduction/scripts/make_plots.py b/evals/elsuite/function_deduction/scripts/make_plots.py new file mode 100644 index 0000000000..4c8f5f5e78 --- /dev/null +++ b/evals/elsuite/function_deduction/scripts/make_plots.py @@ -0,0 +1,256 @@ +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +from evals.utils import log_utils + +palette = { + "Average Baseline": "blue", + "Full Knowledge Best": "blue", + "Full Knowledge Random": "blue", + + "Human": "steelblue", + + "gpt-4-32k": "purple", + "gpt-4-32k w CoT": "purple", + + "gpt-4-base w Few-shot": "orange", + "gpt-4-base w CoT and Few-shot": "orange", + + "gpt-3.5-turbo-16k": "green", + "gpt-3.5-turbo-16k w CoT": "green", + + "gemini-pro": "peru", + "gemini-pro w CoT": "peru", + + "llama-2-13b-chat": "brown", + "llama-2-13b-chat w CoT": "brown", + + "llama-2-70b-chat": "maroon", + "llama-2-70b-chat w CoT": "maroon", + + "mixtral-8x7b-instruct": "grey", + "mixtral-8x7b-instruct w CoT": "grey", +} + +solver_to_name = { 
+ "function_deduction/full_knowledge_best": "Full Knowledge Best", + "function_deduction/full_knowledge_random": "Full Knowledge Random", + "function_deduction/average_baseline": "Average Baseline", + + "human_cli": "Human", + + "gpt-4-32k": "gpt-4-32k", + "function_deduction/cot/gpt-4-32k": "gpt-4-32k w CoT", + + "function_deduction/gpt-4-base": "gpt-4-base w Few-shot", + "function_deduction/cot/gpt-4-base": "gpt-4-base w CoT and Few-shot", + + "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k", + "function_deduction/cot/gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k w CoT", + + "generation/direct/gemini-pro": "gemini-pro", + "function_deduction/cot/gemini-pro": "gemini-pro w CoT", + + "generation/direct/llama-2-13b-chat": "llama-2-13b-chat", + "function_deduction/cot/llama-2-13b-chat": "llama-2-13b-chat w CoT", + + "generation/direct/llama-2-70b-chat": "llama-2-70b-chat", + "function_deduction/cot/llama-2-70b-chat": "llama-2-70b-chat w CoT", + + "generation/direct/mixtral-8x7b-instruct": "mixtral-8x7b-instruct", + "function_deduction/cot/mixtral-8x7b-instruct": "mixtral-8x7b-instruct w CoT", +} + +rename_columns = { + "adjusted_avg_rounds": "adjusted_avg_score", + "sem_adjusted_avg_rounds": "sem_adjusted_avg_score", +} + + +def extract_final_reports( + datadir: Path, rename_solvers: dict, rename_columns: dict +) -> pd.DataFrame: + df_rows = [] + for path, results in sorted(list(log_utils.get_final_results_from_dir(datadir).items())): + spec = log_utils.extract_spec(path) + solver_path = spec["completion_fns"][0] + print("adding report for", solver_path) + df_rows.append( + { + "solver": rename_solvers.get(solver_path, solver_path), + **{rename_columns.get(k, k): v for k, v in results.items()}, + } + ) + df = pd.DataFrame(df_rows) + return df + + +def make_plot( + df, + x_column: str, + y_column: str, + x_err_column: str, + title: str, + xlabel: str, + ylabel: str, + out_path: Path, +): + # Avg rounds until success (failure counts as 40) + plt.figure(figsize=(10, 6)) + ax = 
sns.barplot( + x=x_column, + y=y_column, + data=df, + xerr=df[x_err_column] * 1.96, + palette=palette, + ) + + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(axis="x") + plt.tight_layout() + + # Expanding the x-axis limit + x_lim = ax.get_xlim() + ax.set_xlim([x_lim[0], x_lim[1] * 1.05]) # Increase the upper limit by 5% + + # Annotating each bar with its value + for p in ax.patches: + width = p.get_width() + ax.text( + width + x_lim[1] * 0.02, # x position of text + p.get_y() + p.get_height() / 2, # y position of text + "{:.1f}".format(width), # text to be shown + va="center", + ) # vertical alignment + + plt.savefig(out_path) + return + + +def make_ask_guess_incorrect_plot(df, out_path: Path): + # Ask/Guess/Incorrect + + ask_guess_incorrect_data = { + "solver": df["solver"], + "Ask": df["avg_ask_rounds"], + "SEM Average Ask Rounds": df["sem_avg_ask_rounds"], + "Guess": df["avg_guess_rounds"], + "SEM Average Guess Rounds": df["sem_avg_guess_rounds"], + "Incorrect Format": df["avg_incorrect_format_rounds"], + "SEM Average Incorrect Format Rounds": df["sem_avg_incorrect_format_rounds"], + } + + agi_palette = { + "Ask": "blue", + "Guess": "pink", + "Incorrect Format": "red", + } + + ask_guess_incorrect_df = pd.DataFrame(ask_guess_incorrect_data) + + # Melting the DataFrame to make it suitable for seaborn's factorplot + melted_df = pd.melt( + ask_guess_incorrect_df, + id_vars="solver", + value_vars=["Ask", "Guess", "Incorrect Format"], + var_name="Round Type", + value_name="Average Rounds", + ) + + # Generating the plot for Average Ask/Guess/Incorrect Format Rounds + plt.figure(figsize=(14, 14)) + ax = sns.barplot( + x="Average Rounds", y="solver", hue="Round Type", data=melted_df, palette=agi_palette + ) + + plt.xlabel("Average Number of Rounds") + plt.ylabel("Solver") + plt.title("Distribution of Type of Responses by Model") + plt.grid(axis="x") + plt.legend(title="Response Type") + plt.tight_layout() + + # Expanding the x-axis limit + x_lim = 
ax.get_xlim() + ax.set_xlim([x_lim[0], x_lim[1] * 1.05]) # Increase the upper limit by 5% + + # Annotating each bar with its value + for p in ax.patches: + width = p.get_width() + ax.text( + width + 0.1, # x position of text + p.get_y() + p.get_height() / 2, # y position of text + "{:.1f}".format(width), # text to be shown + va="center", + ) # vertical alignment + + plt.savefig(out_path) + return + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--log-dir", "-d", type=str, required=True) + parser.add_argument("--out-dir", "-o", type=str, default="./outputs") + args = parser.parse_args() + log_dir = Path(args.log_dir) + out_dir = Path(args.out_dir) + + df = extract_final_reports(log_dir, solver_to_name, rename_columns) + + # Drop all columns named "complexity*" + df = df[df.columns.drop(list(df.filter(regex="complexity")))] + + # Creating a new DataFrame with the desired order + ordered_df = df.set_index("solver").loc[list(solver_to_name.values())].reset_index() + print(ordered_df) + + make_plot( + df=ordered_df, + x_column="adjusted_avg_score", + y_column="solver", + x_err_column="sem_adjusted_avg_score", + title="Adjusted Average Score (Lower is Better)", + xlabel="Adjusted Average Score", + ylabel="Solver", + out_path=out_dir / "avg_adjusted_score.png", + ) + + ordered_df["solved_ratio"] = 100 * ordered_df["solved_ratio"] + ordered_df["sem_solved_ratio"] = 100 * ordered_df["sem_solved_ratio"] + make_plot( + df=ordered_df, + x_column="solved_ratio", + y_column="solver", + x_err_column="sem_solved_ratio", + title="Solved Samples Ratio (Higher is Better)", + xlabel="Solved Ratio (%)", + ylabel="Solver", + out_path=out_dir / "solved_ratio.png", + ) + + make_plot( + df=ordered_df, + x_column="avg_success_rounds", + y_column="solver", + x_err_column="sem_avg_success_rounds", + title="Average Number of Rounds for Solved Samples (Lower is Better)", + xlabel="No. 
of Rounds", + ylabel="Solver", + out_path=out_dir / "avg_success_rounds.png", + ) + + make_ask_guess_incorrect_plot( + df=ordered_df, + out_path=out_dir / "ask_guess_incorrect.png", + ) + + +if __name__ == "__main__": + main() diff --git a/evals/elsuite/function_deduction/scripts/run_experiments.sh b/evals/elsuite/function_deduction/scripts/run_experiments.sh new file mode 100755 index 0000000000..4e67f5c7be --- /dev/null +++ b/evals/elsuite/function_deduction/scripts/run_experiments.sh @@ -0,0 +1,27 @@ + +logdir=./logs +timestamp=$(date +%Y%m%d_%H%M%S) +logpathbase="$logdir/$timestamp" + +echo Running experiments and logging to $logpathbase + +# Baselines +oaieval function_deduction/average_baseline function_deduction.easy --record_path "$logpathbase/average_baseline.log" +oaieval function_deduction/full_knowledge_best function_deduction.easy --record_path "$logpathbase/full_knowledge_best.log" +oaieval function_deduction/full_knowledge_random function_deduction.easy --record_path "$logpathbase/full_knowledge_random.log" --extra_eval_params n_repeat=100 + +declare -a SOLVERS=( + gpt-3.5-turbo-16k + gpt-4-32k + function_deduction/cot/gpt-3.5-turbo-16k + function_deduction/cot/gpt-4-32k + function_deduction/gpt-4-base + function_deduction/cot/gpt-4-base +) + +# Models +for solver in "${SOLVERS[@]}" +do + log_name=${solver//\//-} + oaieval $solver function_deduction.easy --record_path "$logpathbase/$log_name.log" +done diff --git a/evals/elsuite/function_deduction/solvers.py b/evals/elsuite/function_deduction/solvers.py new file mode 100644 index 0000000000..4830afe34a --- /dev/null +++ b/evals/elsuite/function_deduction/solvers.py @@ -0,0 +1,173 @@ +from typing import Any + +from evals.elsuite.function_deduction import prompts +from evals.elsuite.function_deduction.eval import CurrentState +from evals.solvers.nested.cot_solver import CoTSolver +from evals.solvers.nested.hhh_solver import HHHSolver +from evals.solvers.solver import SolverResult, SolverSpec +from 
evals.task_state import Message, TaskState + + +class CustomCoT(CoTSolver): + def __init__( + self, + cot_solver: SolverSpec, + extract_solver: SolverSpec, + persistent_memory: bool = True, + registry: Any = None, + ): + super().__init__( + cot_solver=cot_solver, + extract_solver=extract_solver, + persistent_memory=persistent_memory, + ) + + def cot_template(self, task_state: TaskState) -> str: + round_ix = task_state.current_state.round_ix + if round_ix == 0: + return prompts.cot_template_first_round + else: + summary = self._get_summary(task_state.current_state) + return prompts.cot_template_later_rounds.format( + round_ix=round_ix + 1, # displayed round number starts from 1 + num_rounds=task_state.current_state.n_rounds, + summary=summary, + ) + + def _get_summary(self, current_state: CurrentState) -> str: + rows = [] + for key, val in sorted(current_state.known_values.items()): + rows.append(f"calc({key}) = {val}") + + negative_rows = [] + for key, val in sorted(current_state.negative_known_values.items()): + negative_rows.append(f"calc({key}) != {val}") + + parts = [] + if rows: + parts.append("\n".join(rows)) + if negative_rows: + msg = "Information from your incorrect guesses:\n" + parts.append(msg + "\n".join(negative_rows)) + + if not parts: + return "You don't know anything yet." + else: + return "\n\n".join(parts) + + +class BaseModelSolver(HHHSolver): + def _solve(self, task_state: TaskState): + task_state = TaskState( + task_state.task_description, + self._few_shot_messages() + task_state.messages, + task_state.current_state, + ) + result = super()._solve(task_state) + result = result.output.splitlines()[0] + return SolverResult(result) + + def _few_shot_messages(self) -> list[Message]: + role = "system" + messages = [ + (role, "I have a hidden function. 
What is your first action?"), + ("assistant", "40"), + (role, "calc(40) = 160"), + ("assistant", "52"), + (role, "calc(52) = 204"), + ("assistant", "68 144 272"), + (role, "Correct guess!"), + (role, "I now have a new function. Forget about the previous one, we start again."), + ] + return [Message(*row) for row in messages] + + +class BaseModelCoTSolver(CustomCoT): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def cot_solver(self): + # This is a hack to enable us to postprocess the output during the CoT generation step: + # Here, we always keep just first line of the CoT generation, otherwise the base model + # goes on to copy the extraction prompt and confuses itself. + # TODO: Once we have solvers with proper built-in support for output postprocessors, + # update this to use that instead. + def cot_solver(task_state): + result = self.get_solver("cot_solver")(task_state).output + result = result.splitlines()[0] + return SolverResult(result) + + return cot_solver + + def _solve(self, task_state: TaskState): + task_state = TaskState( + task_state.task_description, + self._few_shot_messages(task_state.current_state) + task_state.messages, + task_state.current_state, + ) + + result = super()._solve(task_state) + result = result.output.splitlines()[0] + + # Fix the interaction history so that we can have persistent_memory = True + self.interaction_cache.last_interaction.messages[-1] = Message("assistant", result) + + return SolverResult(result) + + def _few_shot_messages(self, current_state) -> list[Message]: + # This is a bit hackish, but this way we can use self.cot_template (defined on CustomCoT), + # -> we'll have exactly the same system prompts in few shot part as we'd have in the eval + cs = CurrentState(current_state.n_rounds, current_state.mode, current_state.test_inputs) + task_state = TaskState("", [], cs) + + role = "system" + + messages = [ + (role, "I have a hidden function. 
What is your first action?"), + ] + messages.append((role, self.cot_template(task_state))) + messages.append( + ( + "assistant", + ( + "I don't know anything about the function yet. I should start with gathering some information, " + "e.g. by asking about the return value for 40." + ), + ) + ) + messages.append((role, self.extract_template(task_state))) + messages.append(("assistant", "40")) + messages.append((role, "calc(40) = 160")) + cs.ask_update(40, 160) + + messages.append((role, self.cot_template(task_state))) + messages.append( + ( + "assistant", + "Perhaps this function does multiplication by 4? That's not certain yet, so let's try some other number, e.g. 52.", + ) + ) + messages.append((role, self.extract_template(task_state))) + messages.append(("assistant", "52")) + messages.append((role, "calc(52) = 204")) + cs.ask_update(52, 204) + + messages.append((role, self.cot_template(task_state))) + messages.append( + ( + "assistant", + ( + "Now we have two results where the ouput is the input times 4. It seems that the function multiplies by 4. " + "I will make the guess now. 17 * 4 = 68, 36 * 4 = 144 and 68 * 4 = 272, so my guess will be 68 144 272." + ), + ) + ) + messages.append((role, self.extract_template(task_state))) + messages.append(("assistant", "68 144 272")) + messages.append((role, "Correct guess!")) + messages.append( + (role, "I now have a new function. 
Forget about the previous one, we start again.") + ) + + return [Message(*row) for row in messages] diff --git a/evals/elsuite/function_deduction/solvers_test.py b/evals/elsuite/function_deduction/solvers_test.py new file mode 100644 index 0000000000..8fadec107f --- /dev/null +++ b/evals/elsuite/function_deduction/solvers_test.py @@ -0,0 +1,149 @@ +from evals.elsuite.function_deduction.eval import CurrentState +from evals.elsuite.function_deduction.prompts import ( + cot_template_first_round, + cot_template_later_rounds, +) +from evals.elsuite.function_deduction.solvers import BaseModelCoTSolver, CustomCoT +from evals.solvers.solver import SolverSpec +from evals.task_state import Message, TaskState + +dummy_solver_spec = SolverSpec( + { + "class": "evals.solvers.solver:DummySolver", + "args": {}, + } +) + +GUESS_INPUT = 7 +ANSWER = 0 +N_ROUNDS = 10 +ROUNDS_SIMULATED = 2 +MODE = "easy" +TEST_INPUTS = (10, 20, 30) + + +def simulate_dummy_game(solver): + # Init state + task_description = "" # Not used + msgs = [] + cs = CurrentState( + n_rounds=N_ROUNDS, + mode=MODE, + test_inputs=TEST_INPUTS, + ) + + # ROUND 1 + solver_result = solver( + TaskState( + task_description=task_description, + messages=msgs, + current_state=cs, + ) + ) + + msgs.append(Message("assistant", solver_result.output)) + msgs.append(Message("system", f"The answer to your query is {ANSWER}")) + cs.ask_update(GUESS_INPUT, ANSWER) # Collect data for input=7 + + # ROUND 2 + solver_result = solver( + TaskState( + task_description=task_description, + messages=msgs, + current_state=cs, + ) + ) + return solver + + +def test_custom_cot(): + solver = CustomCoT(dummy_solver_spec, dummy_solver_spec) + simulate_dummy_game(solver) + + # Check that the customized CoT generation prompts appear as expected + # (and that the persistent memory in fact persists) + solver_private_memory = solver.interaction_cache.last_interaction.messages + assert solver_private_memory[0].content == cot_template_first_round + assert 
solver_private_memory[2].content == solver._extract_template + assert solver_private_memory[5].content == cot_template_later_rounds.format( + round_ix=ROUNDS_SIMULATED, + num_rounds=N_ROUNDS, + summary=f"calc({GUESS_INPUT}) = {ANSWER}", + ) + assert solver_private_memory[7].content == solver._extract_template + + +def test_base_model_cot_solver(): + solver = BaseModelCoTSolver(dummy_solver_spec, dummy_solver_spec) + simulate_dummy_game(solver) + + # Check that the memory contains the few-shot prompts + # followed by the customized CoT generation prompts + solver_private_memory = solver.interaction_cache.last_interaction.messages + + expected_few_shot_msgs = [ + Message(role="system", content="I have a hidden function. What is your first action?"), + Message( + role="system", + content="This is the first round. Think out loud about a general strategy of solving tasks like this.", + ), + Message( + role="assistant", + content="I don't know anything about the function yet. I should start with gathering some information, e.g. by asking about the return value for 40.", + ), + Message( + role="system", + content="Given the above reasoning, the answer in the format requested by the question is:", + ), + Message(role="assistant", content="40"), + Message(role="system", content="calc(40) = 160"), + Message( + role="system", + content="This is round 2 out of 10.\n\nSummary of the information you have already gathered:\ncalc(40) = 160\n\nThink out loud about the following questions:\n* Do you have any hypothesis on what this function might be doing?\n* If yes, should you try to test it (how?), or just use it to calculate the answer?\n* If not, what additional information should you gather to be able to formulate a hypothesis?\n", + ), + Message( + role="assistant", + content="Perhaps this function does multiplication by 4? That's not certain yet, so let's try some other number, e.g. 
52.", + ), + Message( + role="system", + content="Given the above reasoning, the answer in the format requested by the question is:", + ), + Message(role="assistant", content="52"), + Message(role="system", content="calc(52) = 204"), + Message( + role="system", + content="This is round 3 out of 10.\n\nSummary of the information you have already gathered:\ncalc(40) = 160\ncalc(52) = 204\n\nThink out loud about the following questions:\n* Do you have any hypothesis on what this function might be doing?\n* If yes, should you try to test it (how?), or just use it to calculate the answer?\n* If not, what additional information should you gather to be able to formulate a hypothesis?\n", + ), + Message( + role="assistant", + content="Now we have two results where the ouput is the input times 4. It seems that the function multiplies by 4. I will make the guess now. 17 * 4 = 68, 36 * 4 = 144 and 68 * 4 = 272, so my guess will be 68 144 272.", + ), + Message( + role="system", + content="Given the above reasoning, the answer in the format requested by the question is:", + ), + Message(role="assistant", content="68 144 272"), + Message(role="system", content="Correct guess!"), + Message( + role="system", + content="I now have a new function. 
Forget about the previous one, we start again.", + ), + ] + assert solver_private_memory[: len(expected_few_shot_msgs)] == expected_few_shot_msgs + assert ( + solver_private_memory[len(expected_few_shot_msgs) + 0].content == cot_template_first_round + ) + assert ( + solver_private_memory[len(expected_few_shot_msgs) + 2].content == solver._extract_template + ) + assert solver_private_memory[ + len(expected_few_shot_msgs) + 5 + ].content == cot_template_later_rounds.format( + round_ix=ROUNDS_SIMULATED, + num_rounds=N_ROUNDS, + summary=f"calc({GUESS_INPUT}) = {ANSWER}", + ) + assert ( + solver_private_memory[len(expected_few_shot_msgs) + 7].content == solver._extract_template + ) diff --git a/evals/elsuite/identifying_variables/.gitattributes b/evals/elsuite/identifying_variables/.gitattributes new file mode 100644 index 0000000000..e256da66cb --- /dev/null +++ b/evals/elsuite/identifying_variables/.gitattributes @@ -0,0 +1 @@ +images/*.png filter=lfs diff=lfs merge=lfs -text diff --git a/evals/elsuite/identifying_variables/README.md b/evals/elsuite/identifying_variables/README.md new file mode 100644 index 0000000000..59912f0b27 --- /dev/null +++ b/evals/elsuite/identifying_variables/README.md @@ -0,0 +1,177 @@ +# Identifying Variables + +This eval tests how well models can determine what should be treated as the +independent, dependent, and control variables for an experiment that tests a +particular hypothesis, given some observational context. + +## Usage + +Run with: + +```bash +oaieval identifying_variables +``` + +We have found that `generation/cot/gpt-4-1106-preview` works well on this eval. For more examples of tested solvers, see [`./scripts/run_experiments.sh`](./scripts/run_experiments.sh). + +## Evaluation Process + +The evaluation process is as follows for a given sample from our dataset: + +1. The `TASK_DESCRIPTION` prompt is shown to the solver. +2. 
The sample is passed through a _renderer_ that processes the samples and + renders an observation of the interactions of variables, which is placed in + the `SAMPLE_MESSAGE` prompt template. +3. The solver answers in the form: `[@ANSWER valid_hyp: ; independent: ; dependent: ; control: ]`. The answer is parsed and evaluated by the eval. If the answer cannot be parsed, we mark this as a violation and the sample is treated as incorrect. + +## Prompts + +We refer readers to the [`./prompts.py`](./prompts.py) file for the +`TASK_DESCRIPTION` and `SAMPLE_MESSAGE` prompts used in the eval. + +## Metrics + + +| **Metric** | **Notes** | +|---|---| +| `ctrl_nDCG` | A modified version of the [normalized discounted cumulative gains (nDCG)](https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG) metric, which rewards listing the correct control variables first and penalizes naming irrelevant variables. | +| `ctrl_recall` | Number of variables correctly marked as control variables / total number of variables to control according to the gold label | +| `ctrl_fallout` | Number of variables incorrectly marked as control variables / total number of variables not to control according to the gold label | +| `hyp_valid_acc` | Target hypothesis plausibility validation accuracy (correct/incorrect) | +| `ind_acc` | Independent variable determination accuracy (correct/incorrect) | +| `dep_acc` | Dependent variable determination accuracy (correct/incorrect) | +| `violation_rate` | Number of samples with violations (model failed to answer in correct format) / total number of samples | + + +## Variants + +We support variations on the eval along two dimensions, `renderer` and `dataset`: + +```bash +oaieval identifying_variables.<renderer>.<dataset> +``` + +The eval defaults to `identifying_variables.language-corrset.balanced-ctrl`. 
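As a rough illustration of the control-variable metrics in the Metrics table above, the following minimal sketch mirrors the `compute_DCG` / `compute_nDCG` helpers and the `_ctrl_vars_nDCG` scoring added in this change: each predicted variable scores +1 if it is a gold control variable and -1 otherwise, and the resulting DCG is min-max normalized between the worst and best possible rankings. The function names `dcg`, `ctrl_ndcg`, `ctrl_recall` and `ctrl_fallout` are hypothetical names chosen for this sketch, and returning NaN when a ratio has a zero denominator is an assumption, not necessarily what `metrics.py` does.

```python
import math
from typing import List, Set


def dcg(ranking: List[float]) -> float:
    # The relevance at 1-indexed position i is discounted by log2(i + 1).
    return sum(rel / math.log2(i + 1) for i, rel in enumerate(ranking, start=1))


def ctrl_ndcg(preds: List[str], gold: List[str], num_not_ctrl: int) -> float:
    # +1 for every predicted variable that should be controlled, -1 otherwise.
    ranking = [1.0 if var in gold else -1.0 for var in preds]
    best = dcg([1.0] * len(gold))  # every gold variable listed, nothing else
    worst = dcg([-1.0] * num_not_ctrl)  # only irrelevant variables listed
    # Min-max normalization: the worst ranking maps to 0, the best to 1.
    # Like the source, this assumes best != worst.
    return (dcg(ranking) - worst) / (best - worst)


def ctrl_recall(preds: Set[str], gold: Set[str]) -> float:
    # Correctly marked control variables / gold control variables.
    # Assumption of this sketch: NaN when there is nothing to recall.
    return len(preds & gold) / len(gold) if gold else float("nan")


def ctrl_fallout(preds: Set[str], gold: Set[str], num_not_ctrl: int) -> float:
    # Incorrectly marked control variables / variables that should not be controlled.
    # Assumption of this sketch: NaN when no such variables exist.
    return len(preds - gold) / num_not_ctrl if num_not_ctrl else float("nan")


if __name__ == "__main__":
    gold = ["x_1", "x_2"]
    print(ctrl_ndcg(["x_1", "x_2"], gold, num_not_ctrl=3))  # perfect answer: 1.0
    print(ctrl_ndcg(["x_1", "x_9", "x_2"], gold, num_not_ctrl=3))
    print(ctrl_recall({"x_1", "x_9"}, set(gold)))
    print(ctrl_fallout({"x_1", "x_9"}, set(gold), num_not_ctrl=3))
```

Under this scheme an irrelevant variable lowers the score more the earlier it is listed, since earlier positions are discounted less.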
+ +### Dataset + +We provide 4 dataset variants: + +| `dataset` | Notes | +| --- | --- | +| `balanced-ctrl` | 500 samples balanced across number of control variables (from 0 to 8). | +| `balanced-ctrl-large` | As `balanced-ctrl`, but with 5,000 samples. | +| `balanced-hypotheses` | 500 samples balanced across target hypotheses being implausible/plausible. | +| `balanced-hypotheses-large` | As `balanced-hypotheses`, but with 5,000 samples. | + +### Renderers + +We have 6 different renderers, implemented in [`./renderers/`](./renderers/). + +The default renderer is `language-corrset`. Here is an example render from this type: +``` +The following is a description of the observations made about a set of variables. + +In general, there were cases where some variables changed in tandem with each other, while others did not. +For example, changes in x_5075 were observed to reflect changes in x_3314 and viceversa. +Changes in x_9549 were not observed to reflect any changes in previously mentioned variables. +Changes in x_1808 were not observed to reflect any changes in previously mentioned variables. +Likewise, changes in x_9726 were observed to reflect changes in x_1808 and viceversa. +``` + +### Show Tree + +We provide an additional variant of the eval where the decision tree implementing +the reasoning for scoring a perfect score is shown to the model. This variant +can be run by passing the `show_tree=True` flag to eval, e.g. + +```bash +oaieval identifying_variables --extra_eval_params show_tree=True +``` + +## Custom Solvers + +We implement two custom programmatic solvers to serve as baselines. + +1. `identifying_variables/random`: a solver that randomly selects whether the + hypothesis is plausible with probability 0.5, and if so randomly samples the + independent, dependent and control variables. We view this baseline as + equivalent to randomly guessing. +2. 
`identifying_variables/noctrl`: this is a solver that always outputs an empty + list for the variables to control, essentially eliminating any chance of + false positives. This can provide stronger performance than the random + baseline, since it avoids any penalization for returning incorrect variables, + and can even achieve a perfect score on samples that indeed do not have any + variables to control. + +We refer readers to [`./solvers.py`](./solvers.py) for the implementation of these +solvers. + +## Token Usage Estimates + +We estimated per-run token usage on the default dataset size (500 samples) +for the least and most token-intensive configurations for each model type +(respectively, direct models on `identifying_variables.corrset` with +`show_tree=False`; and CoT models on `identifying_variables.language-tabular` +with `show_tree=True`). + + +| | **input tokens/run** | **output tokens/run** | **total tokens/run** | +|---|---|---|---| +| **GPT-4-base HHH (corrset, no tree)** | 1,200,000 | 250,000 | 1,450,000 | +| **GPT-4-base CoT HHH (language-tabular, with tree)** | 1,500,000 | 240,000 | 1,740,000 | +| **GPT-3.5-turbo Direct (corrset, no tree)** | 430,000 | 88,000 | 518,000 | +| **GPT-3.5-turbo CoT (language-tabular, with tree)** | 780,000 | 14,000 | 794,000 | +| **GPT-4-1106-preview Direct (corrset, no tree)** | 430,000 | 53,000 | 483,000 | +| **GPT-4-1106-preview CoT (language-tabular, with tree)** | 860,000 | 14,000 | 874,000 | + + +These estimates were taken using the `balanced-hypotheses` dataset but should +roughly apply to the `balanced-ctrl` datasets. For `-large` datasets (5,000 +samples), multiply the above numbers by 10. + +## Future modifications + +- Revisit the definition of the headline `ctrl_nDCG` metric +- Devise additional auxiliary metrics to paint a more complete picture +- What if we show the decision trees described in natural language rather than + pseudocode? +- How can we extend this eval to multi-variable dependencies? 
+ +## Version History + +- v0: Initial version released + +## Contribution Statement + +Eval design, implementation, and results evaluation and writeup were primarily +conducted by Giulio Starace. James Aung was of enormous assistance in report +writing, and additionally provided general feedback and project management +throughout the eval. Oliver Jaffe and Jan Betley were responsible for code +reviews throughout the implementation process, along with fine-grained feedback +on the project in general. Additional guidance was provided by (alphabetically +by last name) Steven Adler and Chan Jun Shern, who helped with brainstorming, +research input, and report revisions. + +## Appendix + +### Perfect output decision trees + +The following decision trees are used to determine the perfect output (aka "gold +label") for a given sample. + +--- + + + +**Figure A1**: Decision tree for determining whether a given variable should be +controlled. + +--- + + + +**Figure A2**: Decision tree for determining whether a hypothesis is valid and, if so, +what the independent and dependent variables are. 
+ +--- diff --git a/evals/elsuite/identifying_variables/constants.py b/evals/elsuite/identifying_variables/constants.py new file mode 100644 index 0000000000..60729828c7 --- /dev/null +++ b/evals/elsuite/identifying_variables/constants.py @@ -0,0 +1,19 @@ +# variables that have at least this amount of sparsity are considered to be unobserved +SPARSITY_FOR_UNOBS = 0.8 +# num of variables in a given sample +MIN_VARS = 2 +MAX_VARS = 10 +# num of hypotheses in a given sample +MIN_HYPS = 1 +MAX_HYPS = 3 +# sparse var rate: percentage of variables to sparsify +MIN_SPARSE_VAR_RATE = 0 +MAX_SPARSE_VAR_RATE = 1 +# sparsity: percentage of NaNs in a sparsified variable +MIN_SPARSITY = 0.2 +MAX_SPARSITY = 1 + +# specific to tabular renderers ------------ + +# num of observations +NUM_OBS = 20 diff --git a/evals/elsuite/identifying_variables/eval.py b/evals/elsuite/identifying_variables/eval.py new file mode 100644 index 0000000000..31b3b743e0 --- /dev/null +++ b/evals/elsuite/identifying_variables/eval.py @@ -0,0 +1,292 @@ +""" +Implementation logic for Identifying Variables eval +""" +import logging +import random +from dataclasses import asdict +from typing import Dict, List, Optional, Tuple + +import networkx as nx +import numpy as np + +from evals.elsuite.identifying_variables import constants, graph_utils, prompts +from evals.elsuite.identifying_variables.metrics import ( + compute_fallout, + compute_nDCG, + compute_recall, +) +from evals.elsuite.identifying_variables.renderers import RENDERER_MAP +from evals.elsuite.identifying_variables.scripts.gen_data import gen_samples +from evals.elsuite.identifying_variables.structs import Answer, Sample +from evals.elsuite.identifying_variables.utils import json_to_sample, parse_solver_preds +from evals.eval import SolverEval +from evals.record import RecorderBase, record_metrics +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import Message, TaskState + 
+logging.getLogger("httpx").setLevel(logging.WARNING) + + +class IdentifyingVariables(SolverEval): + def __init__( + self, + renderer: str, + n_samples: Optional[int] = None, + show_tree: bool = False, + group_metrics: bool = False, + debug: bool = False, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.rng: random.Random = random.Random(self.seed) + self.np_rng: np.random.Generator = np.random.default_rng(self.seed) + self.renderer = RENDERER_MAP[renderer](rng=self.rng, np_rng=self.np_rng) + self.renderer_variant = renderer + self.n_samples = n_samples + self.show_tree = show_tree + self.task_description = self._build_task_description() + self.group_metrics = group_metrics + self.debug = debug + + def _build_task_description(self) -> str: + decision_tree_section = "" + if self.show_tree: + decision_tree_section = prompts.DECISION_TREE_SECTION + return prompts.TASK_DESCRIPTION.format( + optional_decision_tree_section=decision_tree_section, + ).strip() + + def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random) -> None: + message: Message = self._build_message(sample) + + task_state = TaskState( + task_description=self.task_description, + messages=[message], + # to be used by the Random baseline solver only + current_state={"variables": [var for var in sample.causal_graph.nodes]}, + ) + + solver_result: SolverResult = solver(task_state) + + try: + preds = parse_solver_preds(solver_result) + except ValueError: # in case of invalid solver output + preds = None + gold, num_not_ctrl = sample.gold_label, sample.num_not_ctrl + + metrics: Dict[str, float] = self._evaluate_sample(preds, gold, num_not_ctrl) + + record_metrics( + **metrics, + # hack: logviz doesn't support custom log fields, so logging as metric + causal_graph=nx.to_dict_of_lists(sample.causal_graph), + gold_answer=asdict(gold), + n_hyps=sample.hypotheses.number_of_edges(), + valid_hyp=gold.valid_hypothesis, + num_not_ctrl=num_not_ctrl, + ) + + def run(self, recorder: 
RecorderBase) -> Dict[str, float]: + samples: List[Dict] = self._get_samples() + self.rng.shuffle(samples) + self.eval_all_samples(recorder, samples) + metrics: List[Dict] = recorder.get_metrics() + + return self._compute_agg_metrics(metrics) + + def _compute_agg_metrics(self, metrics: List[Dict]) -> Dict[str, float]: + """ + Computes aggregate metrics across all samples + """ + main_metrics = { + "hyp_valid_acc": np.mean([x["hyp_valid_correct"] for x in metrics]), + "violation_count": np.sum([x["violation"] for x in metrics]), + "violation_rate": np.mean([x["violation"] for x in metrics]), + # Some samples may be NaN for cases where the target hypothesis is invalid + "ctrl_nDCG": np.nanmean([x["ctrl_nDCG"] for x in metrics]), + "ctrl_recall": np.nanmean([x["ctrl_recall"] for x in metrics]), + "ctrl_fallout": np.nanmean([x["ctrl_fallout"] for x in metrics]), + "ind_acc": np.nanmean([x["ind_correct"] for x in metrics]), + "dep_acc": np.nanmean([x["dep_correct"] for x in metrics]), + "n_valid_hyp": np.sum([x["valid_hyp"] for x in metrics]), + } + if self.group_metrics: + grouped_metrics = self._compute_grouped_metrics(metrics) + else: + grouped_metrics = {} + + total_metrics = {**main_metrics, **grouped_metrics} + total_metrics = {k: float(v) for k, v in total_metrics.items()} + return total_metrics + + def _compute_grouped_metrics(self, metrics: List[Dict]) -> Dict[str, float]: + """ + Computes metrics aggregated across samples grouped by + - number of variables + - number of roots in random forest + - number of control variables + - number of hypotheses + - max correlation depth + """ + metric_to_agg_func = { + "hyp_valid_acc": np.mean, + "violation_count": np.sum, + "violation_rate": np.mean, + "ctrl_nDCG": np.nanmean, + "ctrl_recall": np.nanmean, + "ctrl_fallout": np.nanmean, + "ind_acc": np.nanmean, + "dep_acc": np.nanmean, + } + raw_metric_names = [ + "hyp_valid_correct", + "violation", + "violation", + "ctrl_nDCG", + "ctrl_recall", + "ctrl_fallout", + 
"ind_correct", + "dep_correct", + ] + group_to_bins = { + "n_vars": np.arange(constants.MIN_VARS, constants.MAX_VARS + 1), + "n_roots": np.arange(1, constants.MAX_VARS + 1), + "n_ctrl_vars": np.arange(0, (constants.MAX_VARS - 2) + 1), + "n_hyps": np.arange(constants.MIN_HYPS, constants.MAX_HYPS + 1), + "max_corr_depth": np.arange(1, constants.MAX_VARS), + } + grouped_metrics = { + f"{metric}-{group}-{g_bin}": [] + for metric in metric_to_agg_func.keys() + for group in group_to_bins.keys() + for g_bin in group_to_bins[group] + } + for log_entry in metrics: + causal_graph = nx.from_dict_of_lists(log_entry["causal_graph"], create_using=nx.DiGraph) + ctrl_vars = log_entry["gold_answer"]["ctrl_vars"] + dep_var = log_entry["gold_answer"]["dep_var"] + group_to_bin = { + "n_vars": causal_graph.number_of_nodes(), + "n_roots": len(graph_utils.find_graph_roots(causal_graph)), + "n_ctrl_vars": len(ctrl_vars) if ctrl_vars is not None else None, + "n_hyps": log_entry["n_hyps"], + "max_corr_depth": graph_utils.find_farthest_node(causal_graph, dep_var)[1] + if dep_var is not None + else None, + } + for group, g_bin in group_to_bin.items(): + if g_bin is not None: + for metric, raw_metric in zip(metric_to_agg_func.keys(), raw_metric_names): + grouped_metrics[f"{metric}-{group}-{g_bin}"].append(log_entry[raw_metric]) + + # aggregate + grouped_metrics = { + k: metric_to_agg_func[k.split("-")[0]](v) + # signal empty groups with np.nan + if len(v) > 0 else np.nan + for k, v in grouped_metrics.items() + } + return grouped_metrics + + def _evaluate_sample(self, preds: Optional[Answer], gold: Answer, num_not_ctrl: int) -> Dict: + """ + If the gold hypothesis is invalid, then all other metrics are skipped, and we + only evaluate whether the solver correctly identified the hypothesis as invalid. + + Mistakes are propagated: If the solver incorrectly identifies a hypothesis as + invalid, then its missing answers for the remaining tasks are counted as wrong. 
+ + In case of violations, the worst possible metrics are recorded, accounting for + the gold hypothesis validity caveat above (e.g. if the gold hypothesis is + invalid, then the worst case ctrl_nDCG is NaN since we'd skip this anyway, + whereas if the gold hypothesis were valid, then the worst case ctrl_nDCG would + be 0.0) + """ + hyp_valid_correct = preds.valid_hypothesis == gold.valid_hypothesis if preds else False + + if gold.valid_hypothesis: + ind_correct = preds.ind_var == gold.ind_var if preds else False + dep_correct = preds.dep_var == gold.dep_var if preds else False + ctrl_nDCG = ( + self._ctrl_vars_nDCG(preds.ctrl_vars, gold.ctrl_vars, num_not_ctrl) + if preds and preds.ctrl_vars is not None + else 0.0 + ) + ctrl_recall = ( + self._ctrl_vars_recall(preds.ctrl_vars, gold.ctrl_vars) + if preds and preds.ctrl_vars is not None + else 0.0 + ) + # not in final report, since experiments had already been run + ctrl_fallout = ( + self._ctrl_vars_fallout(preds.ctrl_vars, gold.ctrl_vars, num_not_ctrl) + if preds and preds.ctrl_vars is not None + else 1.0 + ) + + else: + ctrl_nDCG = np.nan + ctrl_recall = np.nan + ctrl_fallout = np.nan + ind_correct = np.nan + dep_correct = np.nan + + return { + "ctrl_nDCG": ctrl_nDCG, + "ctrl_recall": ctrl_recall, + "ctrl_fallout": ctrl_fallout, + "ind_correct": ind_correct, + "dep_correct": dep_correct, + "hyp_valid_correct": hyp_valid_correct, + "violation": preds is None, + } + + def _ctrl_vars_fallout(self, preds: List[str], gold: List[str], num_not_ctrl: int) -> float: + return compute_fallout(set(preds), set(gold), num_not_ctrl) + + def _ctrl_vars_recall(self, preds: List[str], gold: List[str]) -> float: + return compute_recall(set(preds), set(gold)) + + def _ctrl_vars_nDCG(self, preds: List[str], gold: List[str], num_not_ctrl: int) -> float: + best = [1.0] * len(gold) + ranking = [1.0 if var in gold else -1.0 for var in preds] + worst_case_ctrl = [-1.0] * num_not_ctrl + return compute_nDCG(ranking, best, worst_case_ctrl) + 
+ def _build_message(self, sample: Sample) -> Message: + observations: str = self.renderer.render_obs(sample) + hypotheses: List[str] = self._render_hypotheses(sample.hypotheses) + target_hypothesis: str = self._render_hypothesis(sample.target_hypothesis) + + message_content = prompts.SAMPLE_MESSAGE.format( + observations=observations, + hypotheses=hypotheses, + target_hypothesis=target_hypothesis, + ).strip() + message = Message("user", content=message_content) + + return message + + def _render_hypotheses(self, hypotheses: nx.DiGraph) -> List[str]: + hyp_list = [(n, adj) for n in hypotheses for adj in hypotheses[n]] + return [self._render_hypothesis(h) for h in hyp_list] + + def _render_hypothesis(self, hypothesis: Tuple[str, str]) -> str: + hyp_template = self.rng.choice(prompts.hypothesis_templates) + rendered_hyp = hyp_template.format(ind=hypothesis[0], dep=hypothesis[1]) + return rendered_hyp + + def _get_samples(self) -> List[Sample]: + if self.debug: + return gen_samples(n_samples=1000, signal_noise_ratio=None, np_rng=self.np_rng) + + dict_samples = self.get_samples() + if self.n_samples is not None: + assert ( + len(dict_samples) >= self.n_samples + ), f"Can't get {self.n_samples} samples from a dataset with {len(dict_samples)} samples" + np.random.default_rng(seed=self.seed).shuffle(dict_samples) + dict_samples = dict_samples[: self.n_samples] + samples = [json_to_sample(dict_sample) for dict_sample in dict_samples] + return samples diff --git a/evals/elsuite/identifying_variables/graph_utils.py b/evals/elsuite/identifying_variables/graph_utils.py new file mode 100644 index 0000000000..815ab968cc --- /dev/null +++ b/evals/elsuite/identifying_variables/graph_utils.py @@ -0,0 +1,254 @@ +"""Utils for network graph related operations.""" +from typing import Any, List, Optional, Set, Tuple, Union + +import networkx as nx +import numpy as np + + +def val_and_count_roots( + nodes: List[str], + np_rng: np.random.Generator, + total_edges: Optional[int] = None, + 
min_roots: Optional[int] = None, +) -> int: + """ + Validates the parameters for the construction of a random forest via + `gen_random_forest` and determines the min number of roots to use. + + A random forest following the constraints of `gen_random_forest` with + N nodes will have + - R <= N roots + - E <= N - R edges + If min_roots is not specified, then E <= N - 1, since R >= 1. + """ + n_nodes = len(nodes) + if min_roots is not None: + assert min_roots <= n_nodes, "Total roots must be less than or equal to the number of nodes" + if total_edges is not None: + assert ( + 0 <= total_edges <= n_nodes - min_roots + ), "Total edges must be between 0 and the number of nodes minus the number of roots" + else: + if total_edges is None: + min_roots = np_rng.integers(1, n_nodes + 1) + else: + assert ( + 0 <= total_edges <= n_nodes - 1 + ), "Total edges must be between 0 and the number of nodes minus 1" + # if total edges is specified, then we have an upper bound on R, R <= N - E + max_roots = n_nodes - total_edges + min_roots = np_rng.integers(1, max_roots + 1) + + return min_roots + + +def gen_random_forest_tree_size( + nodes: List[str], + tree_size: int, + np_rng: Optional[np.random.Generator] = None, +) -> nx.DiGraph: + """ + Builds a random forest, i.e. a Directed Acyclic Graph (DAG) + with potentially more than one root. + + We enforce the following constraints for our purposes: + 1. No self connections + 2. No bi-directional connections + 3. No children with multiple parents + 4. At least one root node (no parents) + 5. No cycles + + We additionally allow the user to specify the size that at least one + of the trees in the forest should be. 
+ + Args: + nodes: A list of node names to build the graph from + tree_size: The number of nodes that at least one of the trees in the forest + should have + np_rng: A numpy random number generator + """ + num_nodes = len(nodes) + assert tree_size <= num_nodes, "Tree size must be less than or equal to the number of nodes" + + max_number_roots = num_nodes - tree_size + 1 + min_number_roots = 1 # 1 root is always reserved to the tree of size tree_size + + np_rng = np_rng or np.random.default_rng() + + num_roots = np_rng.integers(min_number_roots, max_number_roots + 1) + roots = set(np_rng.choice(nodes, num_roots, replace=False).tolist()) + + size_controlled_root = np_rng.choice(list(roots)) + size_controlled_tree_nodes = {size_controlled_root} + + shuffled_nodes = np_rng.permutation(nodes) + + graph_children = set() + + graph = nx.DiGraph() + graph.add_nodes_from(shuffled_nodes) + + while len(size_controlled_tree_nodes) < tree_size: + possible_children = [ + n for n in nodes if n not in size_controlled_tree_nodes and n not in roots + ] + child = np_rng.choice(possible_children) + possible_parents = list(size_controlled_tree_nodes) + parent = np_rng.choice(possible_parents) + graph.add_edge(parent, child) + size_controlled_tree_nodes.add(child) + graph_children.add(child) + + remaining_nodes = set(nodes) - size_controlled_tree_nodes + + for node in remaining_nodes: + possible_children = [ + n + for n in remaining_nodes + # avoid self connections + if n != node and + # avoid cycles and bi-directional conns -> ancestors can't be children + n not in nx.ancestors(graph, node) and + # avoid children with multiple parents + n not in graph_children and + # roots can't be children + n not in roots + ] + num_edges = np_rng.integers(0, len(possible_children) + 1) + children = np_rng.choice(possible_children, num_edges, replace=False).tolist() + + for child in children: + graph.add_edge(node, child) + graph_children.update(children) + + return graph + + +def gen_random_forest( + 
nodes: List[str], + total_edges: Optional[int] = None, + min_roots: Optional[int] = None, + np_rng: Optional[np.random.Generator] = None, +) -> nx.DiGraph: + """ + Builds a random forest, i.e. a Directed Acyclic Graph (DAG) + with potentially more than one root. + + We enforce the following constraints for our purposes: + 1. No self connections + 2. No bi-directional connections + 3. No children with multiple parents + 4. At least one root node (no parents) + 5. No cycles + + Args: + nodes: A list of node names to build the graph from + total_edges: The total number of edges in the graph. If None, will be random. + min_roots: The minimum number of roots in the graph. If None, will be random. + """ + np_rng = np_rng or np.random.default_rng() + graph = nx.DiGraph() + graph.add_nodes_from(nodes) + + min_roots = val_and_count_roots(nodes, np_rng, total_edges, min_roots) + + # the minimal set of roots, there may be more as we create the graph + roots = set(np_rng.choice(nodes, min_roots, replace=False).tolist()) + + graph_children = set() + edge_count = 0 + + shuffled_nodes = np_rng.permutation(nodes) + + for node in shuffled_nodes: + possible_children = [ + n + for n in nodes + # avoid self connections + if n != node and + # avoid cycles and bi-directional conns -> ancestors can't be children + n not in nx.ancestors(graph, node) and + # avoid children with multiple parents + n not in graph_children and + # roots can't be children + n not in roots + ] + + if len(possible_children) == 0: + continue + + if total_edges is not None: + remaining_edges = total_edges - edge_count + if remaining_edges <= 0: + break + num_edges = np_rng.integers(0, min(remaining_edges, len(possible_children)) + 1) + else: + num_edges = np_rng.integers(0, len(possible_children) + 1) + + children = np_rng.choice(possible_children, num_edges, replace=False).tolist() + + for child in children: + graph.add_edge(node, child) + graph_children.update(children) + edge_count += num_edges + + if 
total_edges is not None and edge_count < total_edges: + # If we didn't reach the total number of edges, try again + return gen_random_forest(nodes, total_edges, min_roots, np_rng) + + return graph + + +def find_farthest_node(graph: nx.DiGraph, source: str) -> Tuple[str, int]: + """ + Performs Breadth-First Search (BFS) to find the farthest node from the source node + and the distance to that node. Distance is defined as the number of edges between + the source node and the farthest node. + """ + graph = graph.to_undirected() + + # Compute shortest path lengths from source to all other nodes + path_lengths = nx.single_source_shortest_path_length(graph, source) + + # Find the farthest node + farthest_node = max(path_lengths, key=path_lengths.get) + max_distance = path_lengths[farthest_node] + + return farthest_node, max_distance + + +def find_graph_roots(graph: nx.DiGraph) -> Set[str]: + """ + Finds the root nodes of a graph + """ + return set([n for n, d in graph.in_degree() if d == 0]) + + +def find_graph_trees(graph: nx.DiGraph) -> List[Set[str]]: + """ + Finds the trees of a graph + """ + return [{root, *nx.descendants(graph, root)} for root in find_graph_roots(graph)] + + +def find_connected_nodes_pair( + graph: nx.DiGraph, np_rng: np.random.Generator +) -> Union[Tuple[Any, Any], None]: + """ + Finds a pair of connected nodes in a graph + If no such pair exists, returns None + """ + connected_pair = tuple(np_rng.choice(list(graph.edges))) if graph.edges else None + return connected_pair + + +def find_unconnected_nodes_pair(graph: nx.DiGraph) -> Union[Tuple[Any, Any], None]: + """ + Finds a pair of unconnected nodes in a graph + If no such pair exists, returns None + """ + components = list(nx.connected_components(graph.to_undirected())) + + if len(components) > 1: + return next(iter(components[0])), next(iter(components[1])) + return None diff --git a/evals/elsuite/identifying_variables/images/control_var_tree.png 
b/evals/elsuite/identifying_variables/images/control_var_tree.png new file mode 100755 index 0000000000..59de243e29 --- /dev/null +++ b/evals/elsuite/identifying_variables/images/control_var_tree.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60bbedac103bae669c4cec1037faaa18b87df63ab5d2c61734f2c60211240fd6 +size 273556 diff --git a/evals/elsuite/identifying_variables/images/valid_hyp_tree.png b/evals/elsuite/identifying_variables/images/valid_hyp_tree.png new file mode 100644 index 0000000000..d005e47b47 --- /dev/null +++ b/evals/elsuite/identifying_variables/images/valid_hyp_tree.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:758a23f6b4bd7676852af320b28f8b6af61c404d22835eda99f2b8dc89a0277b +size 69394 diff --git a/evals/elsuite/identifying_variables/latent_funcs.py b/evals/elsuite/identifying_variables/latent_funcs.py new file mode 100644 index 0000000000..6f66a1c44e --- /dev/null +++ b/evals/elsuite/identifying_variables/latent_funcs.py @@ -0,0 +1,43 @@ +"""Latent functions for the project.""" +import numpy as np + + +def linear(x: np.ndarray, grad: float, bias: float) -> np.ndarray: + return grad * x + bias + + +def quadratic(x: np.ndarray, grad: float, bias: float) -> np.ndarray: + return grad * x**2 + bias + + +def random_uniform(num_samples, min_v, max_v, rng: np.random.Generator) -> np.ndarray: + return rng.uniform(min_v, max_v, num_samples) + + +def random_ints(num_samples, min_v, max_v, rng: np.random.Generator) -> np.ndarray: + return rng.integers(min_v, max_v, num_samples) + + +LATENT_FUNC_MAP = { + "linear": linear, + "quadratic": quadratic, +} +LATENT_FUNC_KWARG_MAP = { + "linear": { + "grad": {"min_v": -10, "max_v": 10}, + "bias": {"min_v": -100, "max_v": 100}, + }, + "quadratic": { + "grad": {"min_v": -10, "max_v": 10}, + "bias": {"min_v": -100, "max_v": 100}, + }, +} + +DISTRIBUTIONS = { + # "random_uniform": random_uniform, + "random_ints": random_ints, +} +DISTRIBUTIONS_KWARG_MAP = { + 
"random_uniform": {"min_v": -1, "max_v": 1}, + "random_ints": {"min_v": -100, "max_v": 100}, +} diff --git a/evals/elsuite/identifying_variables/metrics.py b/evals/elsuite/identifying_variables/metrics.py new file mode 100644 index 0000000000..501ec3b1a9 --- /dev/null +++ b/evals/elsuite/identifying_variables/metrics.py @@ -0,0 +1,105 @@ +from typing import Dict, List, Set + +import numpy as np + +from evals.elsuite.identifying_variables.utils import parse_solver_preds +from evals.solvers.solver import SolverResult + + +def compute_DCG(ranking: List[float], ceil_negs: bool = False) -> float: + """ + Computes the DCG of a ranking + """ + dcg = 0 + for i, rel in enumerate(ranking, start=1): + if ceil_negs: + rel = max(rel, 0) + dcg += rel / np.log2(i + 1) # (i+1) to avoid log_2(1) which = 0 + return dcg + + +def compute_nDCG(ranking: List[float], best: List[float], worst: List[float]) -> float: + """ + Computes nDCG, allowing for negative scores, based on the nDCG variant + from Gienapp et al. (2020) (https://dl.acm.org/doi/10.1145/3340531.3412123) + """ + idcg = compute_DCG(best) + min_dcg = compute_DCG(worst) + dcg = compute_DCG(ranking) + return (dcg - min_dcg) / (idcg - min_dcg) + + +def compute_metric_posthoc( + metric: str, metric_entries: List[Dict], sampling_entries: List[Dict] +) -> float: + """ + Computes a metric that was not logged by the eval, post-hoc, i.e. + after the eval has run, by reading the log file. + """ + metric_to_func = { + "ctrl_recall": compute_ctrl_recall_posthoc, + } + if metric not in metric_to_func.keys(): + raise ValueError(f"Metric {metric} not supported") + return metric_to_func[metric](metric_entries, sampling_entries) + + +def compute_ctrl_recall_posthoc(metric_entries: List[Dict], sampling_entries: List[Dict]) -> float: + """ + Computes the average recall for identified control variables + + i.e. the no. of correctly identified control variables / no. gold control variables + Averaged across the samples. 
+ + - We skip any samples where the gold hypothesis is invalid + - And we skip any samples where there are no control variables in the gold label, + since recall is undefined in this case + """ + recalls = [] + for metric_entry, sampling_entry in zip(metric_entries, sampling_entries): + try: + preds = parse_solver_preds(SolverResult(output=sampling_entry["sampled"][0])) + except ValueError: # in case of invalid solver output (violation) + preds = None + + if metric_entry["gold_answer"]["valid_hypothesis"]: + if preds and preds.ctrl_vars is not None: + recall = compute_recall( + set(preds.ctrl_vars), set(metric_entry["gold_answer"]["ctrl_vars"]) + ) + else: + # worst case scenario in case of violation or incorrect hyp validation + recall = 0 + else: + recall = np.nan + recalls.append(recall) + return np.nanmean(recalls).astype(float) + + +def compute_fallout(retrieved: Set[str], gold_relevants: Set[str], num_irrelevant: int) -> float: + """ + Computes fallout for a sample + Number of retrieved irrelevant items / number of irrelevant items + + This value is undefined when there are no irrelevant items in the gold label + """ + if num_irrelevant == 0: + # undefined + return np.nan + retrieved_irrel_count = len([r for r in retrieved if r not in gold_relevants]) + return retrieved_irrel_count / num_irrelevant + + +def compute_recall(retrieved: Set[str], gold_relevants: Set[str]): + """ + Computes recall for a sample + Number of retrieved relevant items / number of relevant items + + This value is undefined when there are no relevant items in the gold label + """ + num_relevant = len(gold_relevants) + if num_relevant == 0: + # undefined + return np.nan + retrieved_rel_count = len([r for r in retrieved if r in gold_relevants]) + return retrieved_rel_count / num_relevant diff --git a/evals/elsuite/identifying_variables/prompts.py b/evals/elsuite/identifying_variables/prompts.py new file mode 100644 index 0000000000..4c2eecefda --- /dev/null +++ 
b/evals/elsuite/identifying_variables/prompts.py @@ -0,0 +1,103 @@ +TASK_DESCRIPTION = """\ +Your task is to determine what the correct independent, dependent\ + and control variables should be when designing an experiment for validating a specific\ + hypothesis, given some observations. + +The observations are organized into N different variables describing some environment.\ + The rules of this environment are unknown. The order in which the variables are\ + presented is arbitrary, and so are their names. In some cases, some if not all\ + observations for a variable may be missing, which simply means that those observations\ + have not been collected for that variable. + +You will also be presented with a list of hypotheses regarding the effect on some\ + variable by some other variable(s). From this list, a "target hypothesis" is chosen for\ + you to focus your choice of variables on. + +You should determine the independent, dependent and control variables based on the above. + +If there is more than one control variable, please separate them with commas, and order\ + them from most important to least important. If there are no control variables, output\ + "None" for the control variable field. Note that you are budget-constrained, so\ + you should only list control variables which would be quite risky not to control, given\ + what is known. + +If the target hypothesis is unlikely to be worth investigating, it is considered\ + invalid, and you should determine this as part of your answer. + +Your answer should be in the following format, including the square brackets: + +[@ANSWER valid_hyp: ; independent: ; dependent: ; control: ] + +For example, \ + [@ANSWER valid_hyp: true; independent: x_1; dependent: x_2; control: x_3, x_4] + +In case of invalid hypothesis, you may ignore the other fields\ + e.g. [@ANSWER valid_hyp: false; independent: NA; dependent: NA; control: NA] or \ + [@ANSWER valid_hyp: false] + +Only perform reasoning if requested. 
+ +{optional_decision_tree_section} +""" + + +DECISION_TREE_SECTION = """\ +To determine the validity and thereby the independent and dependent variables of a\ + hypothesis stating that x_1 is a function of x_2, you can refer to the following\ + pseudocode + +```pseudocode +if correlated(x_1, x_2) is not None and not correlated(x_1, x_2): + "invalid hypothesis" +else: + "independent: x_2; dependent: x_1" +``` + +where `correlated` returns `True` if its arguments are correlated `False` if not,\ + and `None` if it is unknown. + +To determine whether a given variable x_n should be a control variable in an experiment\ + testing a (valid) hypothesis stating that x_1 is a function of x_2, you can refer to the\ + following pseudocode: + +```pseudocode +if x_n in {x_1, x_2}: + "do not control for x_n" +else: + if correlated(x_1, x_n) or correlated(x_2, x_n): + "control for x_n" + else: + if correlated(x_1, x_n) is not None: + "do not control for x_n" + else: + if hypothesized(ind=x_n, dep=x_1, allow_indirect=True): + "control for x_n" + else: + "do not control for x_n" +``` + +where `hypothesized` returns whether `ind` is hypothesized to be a cause of `dep`,\ + even indirectly through chains of hypotheses. +""" + + +SAMPLE_MESSAGE = """\ +Observations: + +{observations} + +Hypotheses: + +{hypotheses} + +Target Hypothesis: + +{target_hypothesis} +""" + + +hypothesis_templates = [ + "{dep} is a function of {ind}", + "{ind} affects {dep} through some function", + "{dep} is affected by {ind} through some function", +] diff --git a/evals/elsuite/identifying_variables/renderers/__init__.py b/evals/elsuite/identifying_variables/renderers/__init__.py new file mode 100644 index 0000000000..c155624761 --- /dev/null +++ b/evals/elsuite/identifying_variables/renderers/__init__.py @@ -0,0 +1,11 @@ +from . import tabular +from . 
import corrset + +RENDERER_MAP = { + "markdown": tabular.MarkdownTableRenderer, + "csv": tabular.CSVTableRenderer, + "json": tabular.JSONTableRenderer, + "language-tabular": tabular.LanguageTableRenderer, + "language-corrset": corrset.LanguageCorrSetRenderer, + "corrset": corrset.PureCorrSetRenderer, +} diff --git a/evals/elsuite/identifying_variables/renderers/base.py b/evals/elsuite/identifying_variables/renderers/base.py new file mode 100644 index 0000000000..90c1d27ae5 --- /dev/null +++ b/evals/elsuite/identifying_variables/renderers/base.py @@ -0,0 +1,16 @@ +import abc +import random + +import numpy as np + +from evals.elsuite.identifying_variables.structs import Sample + + +class RendererBase(abc.ABC): + def __init__(self, rng: random.Random, np_rng: np.random.Generator) -> None: + self.rng = rng + self.np_rng = np_rng + + @abc.abstractmethod + def render_obs(self, sample: Sample) -> str: + raise NotImplementedError diff --git a/evals/elsuite/identifying_variables/renderers/corrset.py b/evals/elsuite/identifying_variables/renderers/corrset.py new file mode 100644 index 0000000000..39563527a6 --- /dev/null +++ b/evals/elsuite/identifying_variables/renderers/corrset.py @@ -0,0 +1,346 @@ +from typing import List, Set, Tuple + +from evals.elsuite.identifying_variables.structs import Sample +from evals.elsuite.identifying_variables.renderers.base import RendererBase +import evals.elsuite.identifying_variables.graph_utils as graph_utils +import evals.elsuite.identifying_variables.renderers.templates as templates +from evals.elsuite.identifying_variables.constants import SPARSITY_FOR_UNOBS + + +class CorrSetRenderer(RendererBase): + """ + Describes the correlation structure of variables + """ + + def determine_sample_type(self, sample: Sample) -> Tuple[str, List[Set[str]]]: + """ + Determines the type of sample we have, returning the correlation sets in + the process. Accounts for unobserved variables by removing them from + the correlation sets. 
+ + Returns: + str: The type of causal graph we have, ignoring unobserved variables. + Either + - "many_correl_sets": there are at least two correlation sets, at least + one of which has at least two variables. + - "single_correl_set": there is only one correlation set. + - "only_ind": there are at least two correlation sets, all of which + have exactly one variable. + List[Set[str]]: The list of correlation sets. A correlation set is the + set of observed variables in a tree from the causal graph + """ + causal_graph = sample.causal_graph + graph_trees = graph_utils.find_graph_trees(causal_graph) + correl_sets = [] + unobserved_vars = set( + var + for var in sample.variable_metadata + if sample.variable_metadata[var]["extra"]["sparsity_rate"] + > SPARSITY_FOR_UNOBS + ) + for tree in graph_trees: + correl_set = set(tree) + for var in tree: + if var in unobserved_vars: + # correlations to unobserved variables are, well, unobserved + correl_set.remove(var) + correl_sets.append(correl_set) + # need to check for empty sets, since we removed unobserved variables + correl_sets = [correl_set for correl_set in correl_sets if len(correl_set) > 0] + if len(correl_sets) == 1: + return "single_correl_set", correl_sets + else: + for correl_set in correl_sets: + if len(correl_set) > 1: + # at least one set with more than one observed var + return "many_correl_sets", correl_sets + # all sets have only one node + return "only_ind", correl_sets + + def _get_hypd_unobserved_vars(self, sample: Sample) -> List[str]: + vars_to_mention = [] + hypotheses = sample.hypotheses + + hypothesized_vars = set( + var + for var in hypotheses + if hypotheses.in_degree(var) > 0 or hypotheses.out_degree(var) > 0 + ) + vars_to_mention = [ + var + for var in hypothesized_vars + if sample.variable_metadata[var]["extra"]["sparsity_rate"] + > SPARSITY_FOR_UNOBS + ] + return vars_to_mention + + +class PureCorrSetRenderer(CorrSetRenderer): + def render_obs(self, sample: Sample) -> str: + _, observed_sets = 
self.determine_sample_type(sample) + + render_string = ( + "The following correlation sets were observed. Variables in the" + " same correlation set are correlated with each other, but not with variables in" + " other correlation sets." + ) + render_string += "\n\n" + self._render_observed_sets(observed_sets) + render_string += "\n\n" + self._render_unobserved_vars(sample) + + return render_string + + def _render_observed_sets(self, observed_sets: List[Set[str]]) -> str: + """ + Renders the observed sets. + """ + render_string = "" + for idx, correl_set in enumerate(observed_sets, start=1): + render_string += f"\nCorrelation set {idx}: {{{', '.join(correl_set)}}}." + return render_string.strip() + + def _render_unobserved_vars(self, sample: Sample) -> str: + """ + Renders the unobserved variables. + """ + unobserved_variables = self._get_hypd_unobserved_vars(sample) + if len(unobserved_variables) == 0: + render_string = "There were no unobserved variables." + else: + render_string = f"Unobserved variables: [{', '.join(unobserved_variables)}]." + return render_string.strip() + + +class LanguageCorrSetRenderer(CorrSetRenderer): + """ + Describes the correlation structure of variables in natural language. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.type_to_renderer = { + "many_correl_sets": self.render_many_sets, + "single_correl_set": self.render_single_set, + "only_ind": self.render_only_ind, + } + + def render_obs(self, sample: Sample) -> str: + """ + Describes the interactions between variables in the sample. + + The description looks like + ``` + {opening statement} + + {description of the interactions} + + {optional mention of unobserved variables that were hypothesized about} + ``` + + The description of the interactions depends on the type of causal graph. 
+ """ + sample_type, observed_sets = self.determine_sample_type(sample) + + opening_statement = templates.OPENING_STATEMENT + main_observation = self.type_to_renderer[sample_type](observed_sets) + unobserved_variables = self.mention_unobserved_vars(sample) + return "\n\n".join([opening_statement, main_observation, unobserved_variables]) + + def render_many_sets(self, correl_sets: List[Set[str]]): + """ + Renders a causal graph where we have at least two correlation + sets, one of which has at least two variables. + The description looks like: + ``` + In general, there were cases where some variables changed in tandem with each + other, while others did not. + {example of two variables that changed in tandem} + {interleaved mentions of remaining variables, specifying which other already + mentioned variables they changed in tandem with, if any} + ``` + """ + # Sort the sets by size, largest first + correl_sets = sorted(correl_sets, key=lambda x: len(x), reverse=True) + variables = [var for correl_set in correl_sets for var in correl_set] + + correl_set_idx_to_already_mentioned_vars = [set() for _ in correl_sets] + var_to_correl_set_idx = { + var: idx for idx, correl_set in enumerate(correl_sets) for var in correl_set + } + return_string = templates.MANY_CORREL_SETS_MAIN + + # hard-code mention first two variables, from first (largest) set + current_set_idx = 0 + return_string += "\n" + templates.CORREL_VARS_EXAMPLE.format( + optional_transition="For example, ", + # the first set is guaranteed to have at least two variables + var_1=variables[0], + var_2=variables[1], + ) + correl_set_idx_to_already_mentioned_vars[0].update([variables[0], variables[1]]) + + # go through remaining variables, randomly + variables = variables[2:] + self.rng.shuffle(variables) + + for var in variables: + correl_set_idx = var_to_correl_set_idx[var] + if correl_set_idx == current_set_idx: + transition_word = self.rng.choice(["Similarly", "Likewise"]) + transition_phrase = 
f"{transition_word}, " + else: + transition_phrase = "" + current_set_idx = correl_set_idx + + mentioned_vars_from_set = correl_set_idx_to_already_mentioned_vars[ + correl_set_idx + ] + if len(mentioned_vars_from_set) == 0: # first time mentioning this set + mention_string = templates.IND_VARS_EXAMPLE.format( + optional_transition=transition_phrase, + var_1=var, + var_2="previously mentioned variables", + ) + else: # variables from this set have been mentioned + mention_string = templates.CORREL_VARS_EXAMPLE.format( + optional_transition=transition_phrase, + var_1=var, + var_2=templates.list_to_nl_list(list(mentioned_vars_from_set)), + ) + return_string += "\n" + mention_string.capitalize() + # we have now mentioned this variable + correl_set_idx_to_already_mentioned_vars[correl_set_idx].add(var) + + return return_string + + def render_single_set(self, correl_sets: List[Set[str]]) -> str: + """ + Renders a causal graph where we have only one correlation set. + By definition, this set has at least two variables. + The description looks like: + ``` + In general, all of the variables seemed to change in tandem with each other. + For example, changes in {var_1} were observed to reflect changes in {var_2} and + viceversa. + {optional example of other pair} + {optional concluding statement that this holds for all pairs} + ``` + """ + correl_set = correl_sets[0] + # we won't use more than 3 variables in the examples. 
+ exemplar_vars = list(correl_set)[:3] + remaining_vars = correl_set - set(exemplar_vars) + # always have at least 2 vars + example_1 = templates.CORREL_VARS_EXAMPLE.format( + optional_transition="", + var_1=exemplar_vars[0], + var_2=exemplar_vars[1], + ) + example_2 = "" + concluding_statement = "" + if len(exemplar_vars) == 3: + example_2 = templates.CORREL_VARS_EXAMPLE.format( + optional_transition="Additionally, ", + var_1=exemplar_vars[2], + var_2=templates.list_to_nl_list(exemplar_vars[:2]), + ) + if len(remaining_vars) > 0: + concluding_statement = templates.SPECIFIC_CONCL_STATEMENT.format( + already_mentioned=templates.list_to_nl_list(exemplar_vars), + remaining_vars=templates.list_to_nl_list(list(remaining_vars)), + ) + return templates.SINGLE_CORREL_SET_MAIN.format( + example_1=example_1, + optional_example_2=example_2, + optional_concluding_statement=concluding_statement, + ) + + def render_only_ind(self, correl_sets: List[Set[str]]) -> str: + """ + Describes a causal graph where we have at least two correlation + sets, all of which have only one variable, i.e. each variable + in the causal graph is independent of all other variables. The + description looks like: + ``` + In general, no discernible patterns were noticed between the variables. + For example, changes in {var_1} were not observed to reflect any changes in + {var_2}. + {optional example of other pair} + {optional concluding statement that this holds for all pairs} + ``` + """ + variables = [var for correl_set in correl_sets for var in correl_set] + num_vars = len(variables) # equal to the number of sets + # there's always at least 2 variables. 
+ example_1 = templates.IND_VARS_EXAMPLE.format( + optional_transition="", + var_1=variables[0], + var_2=variables[1], + ) + example_2 = "" + concluding_statement = "" + if num_vars > 2: + example_2 = templates.IND_VARS_EXAMPLE.format( + optional_transition="Similarly, ", + var_1=variables[0], + var_2=variables[2], + ) + if num_vars > 3: + concluding_statement = templates.SPECIFIC_CONCL_STATEMENT.format( + already_mentioned=templates.list_to_nl_list(variables[:3]), + remaining_vars=templates.list_to_nl_list(variables[3:]), + ) + else: + concluding_statement = templates.GENERIC_CONCL_STATEMENT + + return templates.ONLY_IND_MAIN.format( + example_1=example_1, + optional_example_2=example_2, + optional_concluding_statement=concluding_statement, + ) + + def mention_unobserved_vars(self, sample: Sample) -> str: + """ + Mentions any unobserved variables that were also hypothesized about. + """ + vars_to_mention = self._get_hypd_unobserved_vars(sample) + + n_vars_to_mention = len(vars_to_mention) + if n_vars_to_mention == 0: + return_string = "" + else: + be_plurality = {"singular": "is", "plural": "are"} + be_string = be_plurality["plural" if n_vars_to_mention > 1 else "singular"] + return_string = templates.UNOBS_BUT_HYP_VARS.format( + unobs_but_hyp_vars=templates.list_to_nl_list(vars_to_mention), + be_string=be_string, + ) + return return_string + + +if __name__ == "__main__": + import random + import numpy as np + + list_of_lists = [ + [{"x_1004"}, {"x_1005", "x_1006", "x_1007", "x_1008", "x_1009"}], + [{"x_1007", "x_1008", "x_1009"}, {"x_1010"}], + [{"x_1011"}, {"x_1012", "x_1013"}, {"x_1014"}], # 3 elements + [{"x_1022"}, {"x_1023", "x_1024"}, {"x_1025", "x_1026"}], + [{"x_1030"}, {"x_1031", "x_1032", "x_1033"}, {"x_1034"}, {"x_1035"}], + ] + + np_rng = np.random.default_rng(0) + renderer = PureCorrSetRenderer(random.Random(0), np_rng) + + from evals.elsuite.identifying_variables.scripts.gen_data import gen_samples + import networkx as nx + from pprint import pprint + + 
samples = gen_samples(10, None, np_rng) + + for sample in samples: + print("causal graph", nx.to_dict_of_lists(sample.causal_graph)) + print("hypotheses", list(sample.hypotheses.edges)) + pprint(sample.variable_metadata) + print(renderer.render_obs(sample)) + print("================") diff --git a/evals/elsuite/identifying_variables/renderers/tabular.py b/evals/elsuite/identifying_variables/renderers/tabular.py new file mode 100644 index 0000000000..0feb8b38fe --- /dev/null +++ b/evals/elsuite/identifying_variables/renderers/tabular.py @@ -0,0 +1,200 @@ +from typing import Optional, Tuple, Union, List +import json +import random + +import networkx as nx +import numpy as np +import pandas as pd + +from evals.elsuite.identifying_variables.structs import Sample +from evals.elsuite.identifying_variables.renderers.base import RendererBase +from evals.elsuite.identifying_variables.latent_funcs import ( + DISTRIBUTIONS, + LATENT_FUNC_MAP, +) +from evals.elsuite.identifying_variables.constants import NUM_OBS + + +def apply_noise( + data_df: pd.DataFrame, np_rng: np.random.Generator, snr: Optional[float] = None +) -> pd.DataFrame: + """ + Apply noise to a pandas DataFrame to achieve a specified Signal-to-Noise Ratio + (SNR). + + Args: + data_df (pd.DataFrame): The DataFrame containing the original data. + snr (float): The desired Signal-to-Noise Ratio in decibels (dB). + If None, no noise is applied. 
+ """ + if snr is None: + return data_df + + desired_snr_linear = 10 ** (snr / 10) + + signal_powers = data_df.var() + noise_powers = signal_powers / desired_snr_linear + + noise = pd.DataFrame( + np_rng.normal(0, np.sqrt(noise_powers), data_df.shape), + columns=data_df.columns, + ) + noisy_df = data_df + noise + + return noisy_df + + +def sparsify_data( + data_df: pd.DataFrame, variable_metadata: dict, np_rng: np.random.Generator +) -> pd.DataFrame: + total_obs = data_df.shape[0] + for var in variable_metadata.keys(): + sparsity_rate = variable_metadata[var]["extra"]["sparsity_rate"] + num_missing_obs = int(sparsity_rate * total_obs) + missing_obs_indices = np_rng.choice(total_obs, num_missing_obs, replace=False) + data_df.loc[missing_obs_indices, var] = np.nan + return data_df + + +class TabularRenderer(RendererBase): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.num_obs = NUM_OBS + + def _render_table(self, sample: Sample) -> pd.DataFrame: + variable_metadata = sample.variable_metadata + sample_metadata = sample.sample_metadata + n_obs_samples = self.num_obs + causal_graph = sample.causal_graph + + # "topological sort" from least to most ancestors (i.e. 
least to most dependent) + sorted_vars = nx.topological_sort(causal_graph) + # necessary so that we can generate data in the correct order + + data_dict = {} + for var in sorted_vars: + gen_method = variable_metadata[var]["gen_method"]["name"] + if "input_x" not in variable_metadata[var]["gen_method"]: + distr = DISTRIBUTIONS[gen_method] + distr_kwargs = variable_metadata[var]["gen_method"]["kwargs"] + data_dict[var] = distr( + num_samples=n_obs_samples, **distr_kwargs, rng=self.np_rng + ) + else: + latent_func = LATENT_FUNC_MAP[gen_method] + latent_func_kwargs = variable_metadata[var]["gen_method"]["kwargs"] + input_x = variable_metadata[var]["gen_method"]["input_x"] + data_dict[var] = latent_func(x=data_dict[input_x], **latent_func_kwargs) + + data_df = pd.DataFrame(data_dict) + + # apply noise after generating data + data_df = apply_noise(data_df, self.np_rng, sample_metadata["snr"]) + # apply sparsification after generating data and applying noise + data_df = sparsify_data(data_df, variable_metadata, self.np_rng) + + # round to 3 decimal places + data_df = data_df.round(3) + + return data_df + + +class MarkdownTableRenderer(TabularRenderer): + """ + Renders tabular data as a markdown table with variable names as column names. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def render_obs(self, sample: Sample) -> str: + data_df = self._render_table(sample) + return data_df.to_markdown(index=False) + + +class CSVTableRenderer(TabularRenderer): + """ + Renders tabular data as a comma-separated-values (CSV) file with variable names as + column names. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def render_obs(self, sample: Sample) -> str: + data_df = self._render_table(sample) + return data_df.to_csv(index=False) + + +class JSONTableRenderer(TabularRenderer): + """ + Renders tabular data as a JSON object with variable names as keys and lists of + values as values. 
+ """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def render_obs(self, sample: Sample) -> str: + data_df = self._render_table(sample) + return json.dumps(data_df.to_dict(orient="list")) + + +class LanguageTableRenderer(TabularRenderer): + """ + Renders tabular data as a natural language description of the data, + describing the data row by row. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.num_obs = 10 # set it to 10 + # realistically no one would read more than 10 rows of data one by one + + def render_obs(self, sample: Sample) -> str: + data_df = self._render_table(sample) + variables = list(data_df.columns) + rendered_obs = "" + current_step = "first" + for row in data_df.itertuples(index=False, name=None): + rendered_obs += self._render_row(row, variables, current_step) + "\n" + current_step = "next" + return rendered_obs + + def _render_row( + self, row: Tuple[Union[int, float], ...], variables: List[str], current_step: str + ) -> str: + string = f"On the {current_step} step, " + past_participle_verb = self.rng.choice(["measured", "recorded", "reported"]) + for value, var in zip(row, variables): + if np.isnan(value): + string += f"{var} was not {past_participle_verb}. " + else: + string += ( + f"{var} was {past_participle_verb} to be {format_number(value)}. 
" + ) + return string + + +def format_number(number: Union[int, float]): + """Gets rid of trailing .0s""" + if float(number).is_integer(): + return int(number) + else: + return number + + +if __name__ == "__main__": + # just for quick testing + np_rng = np.random.default_rng(0) + renderer = LanguageTableRenderer(random.Random(0), np_rng) + + from evals.elsuite.identifying_variables.scripts.gen_data import gen_samples + + samples = gen_samples(10, None, np_rng) + + for sample in samples: + print(nx.to_dict_of_lists(sample.causal_graph)) + print(sample.variable_metadata) + print(renderer.render_obs(sample)) + print("================") diff --git a/evals/elsuite/identifying_variables/renderers/templates.py b/evals/elsuite/identifying_variables/renderers/templates.py new file mode 100644 index 0000000000..c7a9000072 --- /dev/null +++ b/evals/elsuite/identifying_variables/renderers/templates.py @@ -0,0 +1,56 @@ +from typing import List + + +def list_to_nl_list(list_of_words: List[str]) -> str: + """ + Converts a list of words into a natural language list. + """ + if len(list_of_words) == 1: + return list_of_words[0] + elif len(list_of_words) == 2: + return f"{list_of_words[0]} and {list_of_words[1]}" + else: + return f"{', '.join(list_of_words[:-1])} and {list_of_words[-1]}" + + +OPENING_STATEMENT = """\ +The following is a description of the observations made about a set of variables. +""".strip() + +MANY_CORREL_SETS_MAIN = """\ +In general, there were cases where some variables changed in tandem with each other,\ + while others did not. +""".strip() + +SINGLE_CORREL_SET_MAIN = """\ +In general, all of the variables seemed to change in tandem with each other. +For example, {example_1} {optional_example_2} {optional_concluding_statement} +""".strip() + +ONLY_IND_MAIN = """\ +In general, no discernible patterns were noticed between the variables. 
+For example, {example_1} {optional_example_2} {optional_concluding_statement} +""".strip() + +CORREL_VARS_EXAMPLE = """\ +{optional_transition}changes in {var_1} were observed to reflect changes in {var_2} and\ + viceversa. +""".strip() + +IND_VARS_EXAMPLE = """\ +{optional_transition}changes in {var_1} were not observed to reflect any changes in\ + {var_2}. +""".strip() + +SPECIFIC_CONCL_STATEMENT = """\ +Similar observations were made for all other pairings within and across\ + {already_mentioned} and {remaining_vars}. +""".strip() + +GENERIC_CONCL_STATEMENT = """\ +Similar observations were made for all other pairings of the observed variables. +""".strip() + +UNOBS_BUT_HYP_VARS = """\ +{unobs_but_hyp_vars} {be_string} not observed but {be_string} hypothesized about. +""".strip() diff --git a/evals/elsuite/identifying_variables/scripts/data.sh b/evals/elsuite/identifying_variables/scripts/data.sh new file mode 100755 index 0000000000..418ebe3fef --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/data.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# generate datasets of size 500 and 5000 +echo "Generating default dataset: 500 samples" +python gen_data.py --n_samples 500 --jsonl_dir ../../../registry/data/identifying_variables/ +echo "Generating large dataset: 5000 samples" +python gen_data.py --n_samples 5000 --jsonl_dir ../../../registry/data/identifying_variables/ +echo "Generating default dataset: 500 samples (balanced ctrl vars)" +python gen_data.py --balanced_ctrl_vars --n_samples 500 --jsonl_dir ../../../registry/data/identifying_variables/ +echo "Generating large dataset: 5000 samples (balanced ctrl vars)" +python gen_data.py --balanced_ctrl_vars --n_samples 5000 --jsonl_dir ../../../registry/data/identifying_variables/ + +echo "Done." 
diff --git a/evals/elsuite/identifying_variables/scripts/gen_data.py b/evals/elsuite/identifying_variables/scripts/gen_data.py new file mode 100644 index 0000000000..14c5f78e28 --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/gen_data.py @@ -0,0 +1,467 @@ +""" +Code for generating .jsonl dataset for identifying variables eval + +Use default argparse args to replicate the dataset used for the report +""" + +from dataclasses import asdict +import os +import argparse +from typing import Dict, List, Optional, Set, Tuple, Any +import json +import copy + +from tqdm.auto import tqdm +import networkx as nx +import numpy as np + +import evals.elsuite.identifying_variables.latent_funcs as latent_funcs +from evals.elsuite.identifying_variables.graph_utils import ( + gen_random_forest, + gen_random_forest_tree_size, + find_graph_roots, + find_unconnected_nodes_pair, + find_connected_nodes_pair, +) +from evals.elsuite.identifying_variables.utils import sample_serializer +from evals.elsuite.identifying_variables.structs import Sample, Answer +import evals.elsuite.identifying_variables.constants as constants + + +def write_to_jsonl( + samples: List[Sample], + jsonl_path: str, +): + with open(jsonl_path, "w") as f: + for sample in samples: + f.write(json.dumps(asdict(sample), default=sample_serializer) + "\n") + + +def random_latent_func_meta( + np_rng: np.random.Generator, input_x: Optional[str] = None +) -> Dict: + """ + Generates random metadata for defining a latent function + + Args: + input_x (Optional[str]): Name of input variable. If None, then + the latent function is a distribution, not dependent on any input. 
+ """ + if input_x is None: + latent_func_name = np_rng.choice(list(latent_funcs.DISTRIBUTIONS.keys())) + predefined_kwargs = latent_funcs.DISTRIBUTIONS_KWARG_MAP[latent_func_name] + kwargs = {**predefined_kwargs} + return {"name": latent_func_name, "kwargs": kwargs} + else: + latent_func_name = np_rng.choice(list(latent_funcs.LATENT_FUNC_MAP.keys())) + predefined_kwargs = latent_funcs.LATENT_FUNC_KWARG_MAP[latent_func_name] + kwargs = {} + for kwarg, min_max in predefined_kwargs.items(): + kwarg_value = np_rng.integers(min_max["min_v"], min_max["max_v"]) + while kwarg == "grad" and kwarg_value == 0: + # don't allow 0 gradient + kwarg_value = np_rng.integers(min_max["min_v"], min_max["max_v"]) + kwargs[kwarg] = kwarg_value + return {"name": latent_func_name, "input_x": input_x, "kwargs": kwargs} + + +def build_var_metadata( + causal_graph: nx.DiGraph, + sparse_var_rate: float, + np_rng: np.random.Generator, +) -> Dict: + """ + Builds the variable metadata for a sample, containing + information on how each variable is generated and which variables + it is correlated with. + + Args: + causal_graph (nx.DiGraph): Causal graph of the sample. + sparse_var_rate (float): Percentage of variables that should be sparsified. + max_sparsity (float): Maximum sparsity rate for sparse variables. + np_rng (np.random.Generator): Random number generator to be used. + """ + var_metadata = {} + + roots = find_graph_roots(causal_graph) + root_to_descendants = {r: nx.descendants(causal_graph, r) for r in roots} + node_to_root = { + n: root + for root, descendants in root_to_descendants.items() + for n in descendants + } + + for var in causal_graph: + if var in roots: + latent_func_meta = random_latent_func_meta(np_rng, input_x=None) + var_root = var + else: + parent = next(causal_graph.predecessors(var)) + latent_func_meta = random_latent_func_meta(np_rng, input_x=parent) + var_root = node_to_root[var] + # variables with a common root are correlated.
Need to copy to avoid mutation + corrs: Set[str] = set(root_to_descendants[var_root]) + if var_root != var: + # remove self-correlation, add correlation to root itself + corrs.remove(var) + corrs.add(var_root) + + var_metadata[var] = { + "gen_method": latent_func_meta, + "corrs": corrs, + "extra": {"sparsity_rate": 0}, + } + + # add sparsity + var_metadata = sparsify_data(var_metadata, sparse_var_rate, np_rng) + + return var_metadata + + +def sparsify_data(var_metadata, sparse_var_rate, np_rng): + num_observed_vars = 0 + orig_var_metadata = copy.deepcopy(var_metadata) + for var in var_metadata.keys(): + if np_rng.uniform(0, 1) < sparse_var_rate: + sparsity_rate = np_rng.uniform( + low=constants.MIN_SPARSITY, high=constants.MAX_SPARSITY + ) + var_metadata[var]["extra"]["sparsity_rate"] = sparsity_rate + if sparsity_rate > constants.SPARSITY_FOR_UNOBS: + # remove unobserved variables from correlations + for corr_var in var_metadata[var]["corrs"]: + var_metadata[corr_var]["corrs"].remove(var) + var_metadata[var]["corrs"] = set() + else: + num_observed_vars += 1 + else: + num_observed_vars += 1 + + # if less than 2 observed variables, sparsification was too much, try again + if num_observed_vars < 2: + var_metadata = sparsify_data(orig_var_metadata, sparse_var_rate, np_rng) + + return var_metadata + + +def gen_sample_balanced_ctrl_vars( + signal_noise_ratio: Optional[float], np_rng: np.random.Generator +) -> Sample: + """ + Generates a sample for the dataset, containing information on how a set + of variables are interlinked, and which hypotheses are currently held. 
+ + This differs from gen_sample in the following ways: + + To simplify: + - The total number of variables in a given sample is fixed to MAX_VARS + - The hypothesis is always valid + + The number of control variables is sampled uniformly between 0 and MAX_VARS-2 + (we subtract 2 since two variables are involved in the hypothesis) + """ + sample_metadata = {"snr": signal_noise_ratio} + + n_vars = constants.MAX_VARS + + sparse_var_rate = np_rng.uniform( + low=constants.MIN_SPARSE_VAR_RATE, high=constants.MAX_SPARSE_VAR_RATE + ) # perc of variables to sparsify + + var_ids = np_rng.choice(np.arange(1000, 10000), size=n_vars, replace=False).astype( + str + ) + var_names = [f"x_{var_id}" for var_id in var_ids] + + num_ctrl_vars = np_rng.integers(low=0, high=n_vars - 1) # high is exclusive + + causal_graph = gen_random_forest_tree_size( + nodes=var_names, tree_size=num_ctrl_vars + 2, np_rng=np_rng + ) + + variable_metadata = build_var_metadata(causal_graph, sparse_var_rate, np_rng) + + target_hypothesis = find_connected_nodes_pair(causal_graph, np_rng) + target_hyp_is_valid = ( + parse_target_hyp(target_hypothesis, variable_metadata)[0] + if target_hypothesis is not None + else None + ) + # try again if the sparsification caused the hypothesis to be invalid + if target_hypothesis is None or not target_hyp_is_valid: + return gen_sample_balanced_ctrl_vars(signal_noise_ratio, np_rng) + + n_hypotheses = np_rng.integers( + low=constants.MIN_HYPS, + high=min(constants.MAX_HYPS, n_vars - 1) + 1, + ) + hypotheses = gen_random_forest(var_names, total_edges=n_hypotheses, np_rng=np_rng) + + hypotheses = integrate_target_hyp(target_hypothesis, hypotheses, np_rng) + + gold_label, num_not_ctrl = determine_gold_label( + target_hypothesis, variable_metadata, hypotheses + ) + + return Sample( + variable_metadata=variable_metadata, + hypotheses=hypotheses, + target_hypothesis=target_hypothesis, + sample_metadata=sample_metadata, + # keep track of underlying ground truth in case want more 
in depth analysis + causal_graph=causal_graph, + gold_label=gold_label, + num_not_ctrl=num_not_ctrl, + ) + + +def gen_sample( + signal_noise_ratio: Optional[float], + np_rng: np.random.Generator, + valid_hyp_requested: Optional[bool] = None, +) -> Sample: + """ + Generates a sample for the dataset, containing information on how a set + of variables are interlinked, and which hypotheses are currently held. + + Args: + signal_noise_ratio (float): Signal-to-noise ratio to be applied to the + observations. If None, no noise is applied. + np_rng (np.random.Generator): Random number generator to be used. + valid_hyp_requested (Optional[bool]): Whether the target hypothesis should be + valid. If None, will be randomly chosen. + + Returns: + Sample: A sample as defined by the `Sample` dataclass. + """ + sample_metadata = {"snr": signal_noise_ratio} + + n_vars = np_rng.integers(low=constants.MIN_VARS, high=constants.MAX_VARS + 1) + sparse_var_rate = np_rng.uniform( + low=constants.MIN_SPARSE_VAR_RATE, high=constants.MAX_SPARSE_VAR_RATE + ) # perc of variables to sparsify + + var_ids = np_rng.choice(np.arange(1000, 10000), size=n_vars, replace=False).astype( + str + ) + var_names = [f"x_{var_id}" for var_id in var_ids] + + causal_graph = gen_random_forest(var_names, np_rng=np_rng) + + variable_metadata = build_var_metadata(causal_graph, sparse_var_rate, np_rng) + + n_hypotheses = np_rng.integers( + low=constants.MIN_HYPS, + high=min(constants.MAX_HYPS, n_vars - 1) + 1, + ) + hypotheses = gen_random_forest(var_names, total_edges=n_hypotheses, np_rng=np_rng) + + if valid_hyp_requested is None: + # 0.5 chance of valid hypothesis + valid_hyp_requested = np_rng.uniform(0, 1) < 0.5 + + if valid_hyp_requested: + target_hypothesis = find_connected_nodes_pair(causal_graph, np_rng) + else: + target_hypothesis = find_unconnected_nodes_pair(causal_graph) + + target_hyp_is_valid = ( + parse_target_hyp(target_hypothesis, variable_metadata)[0] + if target_hypothesis is not None + else None 
+ ) + if target_hypothesis is None or target_hyp_is_valid != valid_hyp_requested: + return gen_sample(signal_noise_ratio, np_rng, valid_hyp_requested) + + hypotheses = integrate_target_hyp(target_hypothesis, hypotheses, np_rng) + + gold_label, num_not_ctrl = determine_gold_label( + target_hypothesis, variable_metadata, hypotheses + ) + + return Sample( + variable_metadata=variable_metadata, + hypotheses=hypotheses, + target_hypothesis=target_hypothesis, + sample_metadata=sample_metadata, + # keep track of underlying ground truth in case want more in depth analysis + causal_graph=causal_graph, + gold_label=gold_label, + num_not_ctrl=num_not_ctrl, + ) + + +def determine_gold_label( + target_hyp, variable_metadata, hypotheses +) -> Tuple[Answer, Optional[int]]: + """ + Determines the ideal `Answer` for a given sample. Additionally returns + the number of variables not controlled for, if the hypothesis is valid, + necessary for nDCG calculation. + """ + valid_hypothesis, ind_var, dep_var = parse_target_hyp(target_hyp, variable_metadata) + if not valid_hypothesis: + ctrl_vars, not_ctrls = None, None + num_not_ctrl = None + else: + ctrl_vars, not_ctrls = determine_ctrl_vars( + variable_metadata, ind_var, dep_var, hypotheses + ) + # worst case ctrl: all vars that aren't meant to be ctrld are ctrld + num_not_ctrl = len(not_ctrls) + + return ( + Answer( + valid_hypothesis=valid_hypothesis, + ind_var=ind_var, + dep_var=dep_var, + ctrl_vars=ctrl_vars, + ), + num_not_ctrl, + ) + + +def parse_target_hyp( + target_hyp: Tuple[str, str], variable_metadata: Dict[str, Any] +) -> Tuple[bool, Optional[str], Optional[str]]: + """Implements decision tree in Figure 2 from eval spec""" + proposed_ind = target_hyp[0] + proposed_dep = target_hyp[1] + + ind_unobserved = ( + variable_metadata[proposed_ind]["extra"]["sparsity_rate"] + > constants.SPARSITY_FOR_UNOBS + ) + dep_unobserved = ( + variable_metadata[proposed_dep]["extra"]["sparsity_rate"] + > constants.SPARSITY_FOR_UNOBS + ) + + # if 
either are unobserved, we have no evidence that they are not correlated + if ind_unobserved or dep_unobserved: + return True, proposed_ind, proposed_dep + # evidence of lack of correlation + elif proposed_dep not in variable_metadata[proposed_ind]["corrs"]: + return False, None, None + # evidence of correlation + else: + return True, proposed_ind, proposed_dep + + +def determine_ctrl_vars( + variable_metadata: Dict[str, Any], + ind_var: str, + dep_var: str, + hypotheses: nx.DiGraph, +) -> Tuple[List[str], List[str]]: + """Implements decision tree in Figure 1 from eval spec""" + ctrl_vars = [] + not_ctrls = [] + for var in variable_metadata: + if var in {ind_var, dep_var}: + not_ctrls.append(var) + elif are_correlated(var, dep_var, variable_metadata) or are_correlated( + var, ind_var, variable_metadata + ): + ctrl_vars.append(var) + elif are_correlated(var, dep_var, variable_metadata) is not None: + # don't control vars which we have observed to be uncorrelated w/ dep + not_ctrls.append(var) + else: # when dep_var or var is unobserved, no evidence of lack of correlation + # control for any var which might influence the dependent variable + dep_var_ancestors = nx.ancestors(hypotheses, dep_var) + if var in dep_var_ancestors: + ctrl_vars.append(var) + else: + not_ctrls.append(var) + + return ctrl_vars, not_ctrls + + +def are_correlated(var_1, var_2, variable_metadata) -> Optional[bool]: + """ + Returns whether two variables are correlated. If there is no evidence + of correlation, returns None. 
+ """ + if ( + variable_metadata[var_1]["extra"]["sparsity_rate"] + > constants.SPARSITY_FOR_UNOBS + or variable_metadata[var_2]["extra"]["sparsity_rate"] + > constants.SPARSITY_FOR_UNOBS + ): + return None + return ( + var_2 in variable_metadata[var_1]["corrs"] + or var_1 in variable_metadata[var_2]["corrs"] + ) + + +def integrate_target_hyp( + target_hyp: Tuple[Any, Any], hyp_graph: nx.DiGraph, np_rng: np.random.Generator +): + """ + Integrates the target hypothesis into the hypotheses graph, respecting + the original edge count by removing a random edge if necessary. + """ + if not hyp_graph.has_edge(*target_hyp): + random_edge_to_remove = np_rng.choice(list(hyp_graph.edges)) + hyp_graph.remove_edge(*random_edge_to_remove) + hyp_graph.add_edge(*target_hyp) + return hyp_graph + + +def gen_samples( + n_samples: int, + signal_noise_ratio: Optional[float], + np_rng: np.random.Generator, + balanced_ctrl_vars: bool = False, +) -> List[Sample]: + samples = [] + if not balanced_ctrl_vars: + for _ in tqdm(range(n_samples)): + sample = gen_sample(signal_noise_ratio, np_rng) + samples.append(sample) + else: + for _ in tqdm(range(n_samples)): + sample = gen_sample_balanced_ctrl_vars(signal_noise_ratio, np_rng) + samples.append(sample) + + return samples + + +def main(args: argparse.Namespace): + np_rng = np.random.default_rng(args.seed) + samples = gen_samples(args.n_samples, args.snr, np_rng, args.balanced_ctrl_vars) + os.makedirs(args.jsonl_dir, exist_ok=True) + if not args.balanced_ctrl_vars: + jsonl_path = os.path.join(args.jsonl_dir, f"{args.n_samples}.jsonl") + else: + jsonl_path = os.path.join( + args.jsonl_dir, f"{args.n_samples}_balanced_ctrl_vars.jsonl" + ) + write_to_jsonl(samples, jsonl_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + + parser.add_argument("--n_samples", type=int, default=5000) + parser.add_argument( + "--snr", + type=float, + default=None, + help="signal-to-noise ratio. 
Default None (no noise is applied.)", + ) + parser.add_argument( + "--jsonl_dir", type=str, default="./evals/registry/data/identifying_variables/" + ) + parser.add_argument("--seed", type=int, default=20220722) + parser.add_argument( + "--balanced_ctrl_vars", + action="store_true", + help="Whether to generate samples with balanced control variables.", + default=False, + ) + args = parser.parse_args() + + main(args) diff --git a/evals/elsuite/identifying_variables/scripts/make_plots.py b/evals/elsuite/identifying_variables/scripts/make_plots.py new file mode 100644 index 0000000000..f29f781492 --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/make_plots.py @@ -0,0 +1,400 @@ +from pathlib import Path +from typing import Dict, Tuple + +import numpy as np +import pandas as pd +from tqdm.auto import tqdm + +from evals.elsuite.identifying_variables.metrics import compute_metric_posthoc +from evals.elsuite.identifying_variables.scripts.plotting_utils import ( + plot_difficulty_bars, + plot_solver_bars, +) +from evals.elsuite.identifying_variables.scripts.table_utils import ( + make_main_metric_table, +) +from evals.utils import log_utils + +NUM_REPEATS = 3 +MAIN_METRICS = [ + "ctrl_nDCG", + "ctrl_recall", + "hyp_valid_acc", + "ind_acc", + "dep_acc", + "violation_rate", +] + +SOLVERS = [ + "generation/direct/gpt-3.5-turbo", + "generation/cot/gpt-3.5-turbo", + "generation/hhh/gpt-4-base", + "generation/cot_hhh/gpt-4-base", + "generation/direct/gpt-4-1106-preview", + "generation/cot/gpt-4-1106-preview", + "generation/cot/mixtral-8x7b-instruct", + "generation/cot/llama-2-70b-chat", + "generation/cot/gemini-pro", + "identifying_variables/random", + "identifying_variables/noctrl", +] + + +RENDERERS = [ + "markdown", + "csv", + "json", + "language-tabular", + "language-corrset", + "corrset", +] + + +def initialize_default_results_dict(): + results_dict = { + metric: { + stat: { + solver: { + renderer: { + "with tree": ([] if stat == "raw" else 0), + "without tree": 
([] if stat == "raw" else 0), + } + for renderer in RENDERERS + } + for solver in SOLVERS + } + for stat in ["raw", "mean", "sem"] + } + for metric in MAIN_METRICS + } + return results_dict + + +def handle_cot_double_sampling(sampling_entries, solver): + if "cot" in solver: + sampling_entries = [ + entry + for entry in sampling_entries + if ( + # for chat models we filter like this + isinstance(entry["prompt"], list) + and entry["prompt"][-1]["content"].startswith( + "Given the above reasoning" + ) + or ( + # for base models we need to filter like this + isinstance(entry["prompt"], str) + and "Given the above reasoning" in entry["prompt"] + ) + ) + ] + return sampling_entries + + +def handle_posthoc_metrics(final_results: Dict, log_path: Path, solver: str): + """ + Computes and includes missing metrics from log file if they are not present + """ + metric_entries = log_utils.extract_individual_results(log_path) + sampling_entries = log_utils.extract_individual_results(log_path, "sampling") + # filter out cot double samplings + sampling_entries = handle_cot_double_sampling(sampling_entries, solver) + # this is necessary because we originally didn't compute recall in the eval + for metric in MAIN_METRICS: + if metric not in final_results.keys(): + final_results[metric] = compute_metric_posthoc( + metric, metric_entries, sampling_entries + ) + + return final_results + + +def populate_default_results_dict(results_dict, results_dir): + for log in tqdm(results_dir.glob("*.log"), total=222): + spec = log_utils.extract_spec(log) + solver = spec["completion_fns"][0] + run_config = spec["run_config"] + renderer = run_config["eval_spec"]["args"]["renderer"] + show_tree = "show_tree=True" in run_config["command"] + tree_key = "with tree" if show_tree else "without tree" + if renderer not in RENDERERS and solver != "identifying_variables/random": + continue + if solver not in SOLVERS: + continue + + final_results = log_utils.extract_final_results(log) + final_results =
handle_posthoc_metrics(final_results, log, solver) + + for metric, value in final_results.items(): + if metric in MAIN_METRICS: + results_dict[metric]["raw"][solver][renderer][tree_key].append(value) + raw = results_dict[metric]["raw"][solver][renderer][tree_key] + results_dict[metric]["mean"][solver][renderer][tree_key] = np.mean(raw) + results_dict[metric]["sem"][solver][renderer][tree_key] = np.std( + raw + ) / np.sqrt(NUM_REPEATS) + for metric in results_dict.keys(): + del results_dict[metric]["raw"] + return results_dict + + +def make_default_tables(results_dict: Dict, save_dir: Path): + for metric in tqdm(MAIN_METRICS): + make_main_metric_table(results_dict, metric, SOLVERS, RENDERERS, save_dir) + + +def extract_default_results_dict(results_dir: Path): + results_dict = initialize_default_results_dict() + results_dict = populate_default_results_dict(results_dict, results_dir) + + return results_dict + + +def make_default_plots(results_dict: Dict, save_dir: Path): + all_solvers = list(results_dict["ctrl_nDCG"]["mean"].keys()) + bar_solvers, baseline_solvers = all_solvers[:-2], all_solvers[-2:] + + metrics = ["ctrl_nDCG", "ctrl_recall"] + metric_labels = ["Control Variable Retrieval nDCG*", "Control Variable Recall"] + fig_heights = [6, 5] + + for metric, metric_label, fig_height in tqdm( + zip(metrics, metric_labels, fig_heights) + ): + plot_solver_bars( + bar_solvers, + baseline_solvers, + results_dict[metric], + metric_label, + fig_height, + save_dir / f"{metric}.png", + ) + + +def extract_large_results_dict(results_dir: Path) -> Dict: + ctrl_nDCG_bins = list(range(0, 9)) + results_dict = {} + for log in tqdm(results_dir.glob("*.log"), total=12): + spec = log_utils.extract_spec(log) + final_results = log_utils.extract_final_results(log) + solver = spec["completion_fns"][0] + renderer = spec["split"] + key = f"{solver};{renderer}" + if key not in results_dict: + results_dict[key] = { + bbin: {"raw": [], "mean": None, "sem": None} for bbin in ctrl_nDCG_bins + } 
+ + for bbin in ctrl_nDCG_bins: + results_dict[key][bbin]["raw"].append( + final_results[f"ctrl_nDCG-n_ctrl_vars-{bbin}"] + ) + for key in results_dict.keys(): + for bbin in ctrl_nDCG_bins: + mean = np.mean(results_dict[key][bbin]["raw"]) + sem = np.std(results_dict[key][bbin]["raw"]) / np.sqrt(NUM_REPEATS) + results_dict[key][bbin]["mean"] = mean + results_dict[key][bbin]["sem"] = sem + del results_dict[key][bbin]["raw"] + + return results_dict + + +def make_large_plot(large_results_dir: Dict, save_dir: Path): + ctrl_vars_bins = list(range(0, 9)) + plot_difficulty_bars( + large_results_dir, ctrl_vars_bins, save_dir / "ctrl_nDCG_difficulty.png" + ) + + +def np_nan_if_none(input_num): + if input_num is None: + return np.nan + else: + return input_num + + +def zero_if_none(input_num): + if input_num is None: + return 0 + else: + return input_num + + +def round_if_not_nan(input_num): + if np.isnan(input_num): + return input_num + else: + return round(input_num) + + +def make_token_per_sample_df(solver_to_eval, solver_to_tokens) -> pd.DataFrame: + tokens_per_sample_df = pd.DataFrame( + index=solver_to_eval.keys(), + columns=[ + "input tokens/sample", + "output tokens/sample", + "total tokens/sample", + ], + ) + for solver in solver_to_tokens.keys(): + # print(solver_to_tokens[solver]) + input_mean = np.nanmean(solver_to_tokens[solver]["input"]) + output_mean = np.nanmean(solver_to_tokens[solver]["output"]) + total_mean = np.nanmean(solver_to_tokens[solver]["total"]) + # print([input_mean, output_mean, total_mean]) + tokens_per_sample_df.loc[solver] = [ + round_if_not_nan(input_mean), + round_if_not_nan(output_mean), + round_if_not_nan(total_mean), + ] + solver_to_index = { + "generation/hhh/gpt-4-base": "HHH GPT-4-base (corrset, no tree)", + "generation/direct/gpt-3.5-turbo": "Direct GPT-3.5-turbo (corrset, no tree)", + "generation/direct/gpt-4-1106-preview": "Direct GPT-4-1106-preview (corrset, no tree)", + "generation/cot_hhh/gpt-4-base": "CoT HHH GPT-4-base (language-tabular, with
tree)", + "generation/cot/gpt-3.5-turbo": "CoT GPT-3.5-turbo (language-tabular, with tree)", + "generation/cot/gpt-4-1106-preview": "CoT GPT-4-1106-preview (language-tabular, with tree)", + } + tokens_per_sample_df = tokens_per_sample_df.rename(index=solver_to_index) + return tokens_per_sample_df + + +def count_tokens(results_dir: Path, total) -> Tuple[Dict, pd.DataFrame]: + eval_names = [ + "identifying_variables.corrset.default", + "identifying_variables.language-tabular.default", + ] + solver_names = [ + "generation/hhh/gpt-4-base", + "generation/direct/gpt-3.5-turbo", + "generation/direct/gpt-4-1106-preview", + "generation/cot_hhh/gpt-4-base", + "generation/cot/gpt-3.5-turbo", + "generation/cot/gpt-4-1106-preview", + ] + solver_to_eval = { + solver: eval_names[0] if "cot" not in solver else eval_names[1] + for solver in solver_names + } + solver_to_tree = { + solver: False if "cot" not in solver else True for solver in solver_names + } + solver_to_tokens = { + solver: {"input": [], "output": [], "total": []} for solver in solver_names + } + total_input = 0 + total_output = 0 + for log in tqdm(results_dir.glob("*.log"), total=total): + spec = log_utils.extract_spec(log) + solver = spec["completion_fns"][0] + if solver not in solver_names: + print(f"Skipping {solver}: token counting not supported.") + continue + eval_name = spec["eval_name"] + seed = spec["run_config"]["seed"] + tree = "show_tree=True" in spec["run_config"]["command"] + samplings = log_utils.extract_individual_results(log, "sampling") + samplings = handle_cot_double_sampling(samplings, solver) + for sampling in samplings: + usage = sampling["usage"] + if ( + solver in solver_to_eval + and eval_name == solver_to_eval[solver] + and seed == 1 + and tree != solver_to_tree[solver] + ): + solver_to_tokens[solver]["input"].append( + np_nan_if_none(usage["prompt_tokens"]) + ) + solver_to_tokens[solver]["output"].append( + np_nan_if_none(usage["completion_tokens"]) + ) + 
solver_to_tokens[solver]["total"].append( + np_nan_if_none(usage["total_tokens"]) + ) + total_input += zero_if_none(usage["prompt_tokens"]) + total_output += zero_if_none(usage["completion_tokens"]) + + total_tokens = {"input": total_input, "output": total_output} + tokens_per_sample_df = make_token_per_sample_df(solver_to_eval, solver_to_tokens) + + return total_tokens, tokens_per_sample_df + + +def make_total_tokens_table(default_total: Dict, large_total: Dict) -> pd.DataFrame: + """ + Makes a dataframe where the index is "default" "large" and the columns are + "input", "output"; showing the total number of input and output tokens for + our experiments on each dataset. + """ + total_tokens_df = pd.DataFrame( + { + "input": [default_total["input"], large_total["input"]], + "output": [default_total["output"], large_total["output"]], + }, + index=["default", "large"], + ) + return total_tokens_df + + +def make_token_count_tables( + default_results_dir: Path, large_results_dir: Path, save_dir: Path +): + default_total_tokens, default_per_sample_tokens_df = count_tokens( + default_results_dir, total=222 + ) + large_total_tokens, _ = count_tokens(large_results_dir, total=12) + + total_tokens_df = make_total_tokens_table(default_total_tokens, large_total_tokens) + + # save the tables + total_tokens_df.to_csv(save_dir / "total_tokens.csv") + default_per_sample_tokens_df.to_csv(save_dir / "per_sample_tokens.csv") + + +def main(default_results_dir: Path, large_results_dir: Path, save_dir: Path): + save_dir.mkdir(parents=True, exist_ok=True) + + print("Parsing default dataset results...") + default_results_dict = extract_default_results_dict(default_results_dir) + print("Making default dataset tables...") + make_default_tables(default_results_dict, save_dir) + print("Making default dataset plots...") + make_default_plots(default_results_dict, save_dir) + + print("Parsing large dataset results...") + large_results_dict = extract_large_results_dict(large_results_dir) + 
print("Making large dataset plot...") + make_large_plot(large_results_dict, save_dir) + + print("Making token count tables...") + make_token_count_tables(default_results_dir, large_results_dir, save_dir) + print("Done.") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Process results") + parser.add_argument( + "--default_results_dir", + type=str, + help="Path to directory containing .log files from experiments on default dataset", + ) + parser.add_argument( + "--large_results_dir", + type=str, + help="Path to directory containing .log files from experiments on large dataset", + ) + parser.add_argument( + "--save_dir", type=str, help="Path to directory to save plots and tables to" + ) + + args = parser.parse_args() + + main( + Path(args.default_results_dir), + Path(args.large_results_dir), + Path(args.save_dir), + ) diff --git a/evals/elsuite/identifying_variables/scripts/plotting_utils.py b/evals/elsuite/identifying_variables/scripts/plotting_utils.py new file mode 100644 index 0000000000..1c80aab042 --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/plotting_utils.py @@ -0,0 +1,163 @@ +from typing import Dict, Iterable, List +from pathlib import Path + +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + + +renderers_of_interest = ["csv", "language-corrset"] + +renderer_to_label = { + "csv": "CSV observations", + "language-corrset": "Correlation set", +} + +cmap = plt.get_cmap("Paired") +colors = np.array([cmap(i) for i in range(len(renderers_of_interest))]) +renderer_to_color = {r: c for r, c in zip(renderers_of_interest, colors)} + +solver_to_label = { + "generation/direct/gpt-3.5-turbo": "Direct gpt-3.5-turbo", + "generation/cot/gpt-3.5-turbo": "CoT gpt-3.5-turbo", + "generation/hhh/gpt-4-base": "HHH gpt-4-base", + "generation/cot_hhh/gpt-4-base": "CoT HHH gpt-4-base", + "generation/direct/gpt-4-1106-preview": "Direct gpt-4-1106-preview", + 
"generation/cot/gpt-4-1106-preview": "CoT gpt-4-1106-preview", + "generation/cot/mixtral-8x7b-instruct": "CoT mixtral-8x7b-instruct\n(Correlation set only)", + "generation/cot/llama-2-70b-chat": "CoT llama-2-70b-chat\n(Correlation set only)", + "generation/cot/gemini-pro": "CoT gemini-pro-1.0\n(Correlation set only)", + "identifying_variables/random": "Random baseline", + "identifying_variables/noctrl": "NoCtrl baseline", +} + +baseline_to_linestyle = { + "identifying_variables/random": "--", + "identifying_variables/noctrl": "-.", +} + +cmap = plt.get_cmap("Set2") +bline_colors = np.array( + [cmap(i) for i in range(0, len(baseline_to_linestyle.keys()) + 0)] +) +baseline_to_color = { + key: color for key, color in zip(baseline_to_linestyle.keys(), bline_colors) +} + + +def plot_solver_bars( + bar_solvers: List[str], + baseline_solvers: List[str], + metric_results: Dict, + metric_label: str, + fig_height: int, + output_path: Path, +): + """ + Plots a side-by-side bar plot of the metric results, showing the + solvers on the x axis and the metric value on the y axis. + + Args: + bar_solvers: The names of solvers to plot. + baseline_solvers: The names of the baseline solvers to plot. 
+ metric_results: A dictionary with k: v of format solver : {mean: value, sem: value} + metric_label: The label for the y axis + fig_height: the height of the figure in inches + output_path: the path to save the figure to + """ + sns.set_context("paper") + sns.set_style("whitegrid") + + bar_width = 0.3 + positions = np.arange(len(bar_solvers)) + + f, ax = plt.subplots(1, 1, dpi=300, figsize=(9, fig_height)) + + for i, renderer in enumerate(renderers_of_interest): + bars = [ + metric_results["mean"][solver][renderer]["without tree"] + for solver in bar_solvers + ] + errors = [ + metric_results["sem"][solver][renderer]["without tree"] + for solver in bar_solvers + ] + + ax.bar( + positions + bar_width * i, + bars, + bar_width, + yerr=errors, + label=renderer_to_label[renderer], + color=renderer_to_color[renderer], + ) + + for baseline_solver in baseline_solvers: + mean = metric_results["mean"][baseline_solver]["corrset"]["without tree"] + sem = metric_results["sem"][baseline_solver]["corrset"]["without tree"] + ax.axhline( + mean, + label=solver_to_label[baseline_solver], + color=baseline_to_color[baseline_solver], + linestyle=baseline_to_linestyle[baseline_solver], + ) + ax.axhspan( + mean - sem, mean + sem, alpha=0.1, color=baseline_to_color[baseline_solver] + ) + + ax.set_xticks( + positions + bar_width / 2, + [solver_to_label[s] for s in bar_solvers], + rotation=45, + ha="right", + ) + ax.tick_params( + axis="x", which="both", bottom=True + ) # Show both major and minor xticks + ax.set_ylabel(metric_label) + ax.set_ylim(-0.005, 1) + ax.xaxis.grid(False) + ax.legend() + f.set_tight_layout(True) + plt.savefig(output_path, dpi=300, bbox_inches="tight") + + +def plot_difficulty_bars(results_dict: Dict, bins: Iterable[int], output_path: Path): + sns.set_context("paper") + sns.set_style("whitegrid") + + f, ax = plt.subplots(1, 1, dpi=300, figsize=(7, 4)) + + positions = np.arange(len(bins)) + bar_width = 0.4 + + for i, key in enumerate(sorted(results_dict.keys())): + 
solver, renderer = key.split(";") + bars = [results_dict[key][bbin]["mean"] for bbin in bins] + errors = [results_dict[key][bbin]["sem"] for bbin in bins] + if solver == "generation/direct/gpt-4-1106-preview": + label = renderer_to_label[renderer] + color = renderer_to_color[renderer] + ax.bar( + positions + bar_width * i, + bars, + bar_width, + yerr=errors, + label=label, + color=color, + ) + + ax.set_xlabel("Number of necessary control variables") + ax.set_ylabel("Control Variable Retrieval nDCG*") + + ax.set_xlim(-0.3, 8.7) + ax.set_ylim(0, 1) + ax.xaxis.grid(False) + ax.legend() + ax.set_xticks(positions + bar_width / 2, bins) + f.set_tight_layout(True) + plt.savefig( + output_path, + dpi=300, + bbox_inches="tight", + ) diff --git a/evals/elsuite/identifying_variables/scripts/run_experiments.sh b/evals/elsuite/identifying_variables/scripts/run_experiments.sh new file mode 100755 index 0000000000..fae5ceb93b --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/run_experiments.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +# Function to display usage +usage() { + echo "Usage: $0 -s size -l logdir" + echo " -s size Specify the size of the experiments (options: 'balanced-hypotheses', 'balanced-ctrl', 'balanced-hypotheses-large', 'balanced-ctrl-large')" + echo " -l logdir Specify the directory for log files" + exit 1 +} + +# Check if no arguments were provided +if [ $# -eq 0 ]; then + usage + exit 1 +fi + +# Parse command-line options +while getopts 's:l:' flag; do + case "${flag}" in + s) size=${OPTARG} ;; + l) logdir=${OPTARG} ;; + *) usage ;; + esac +done + +# Check if mandatory arguments were provided +if [ -z "$size" ] || [ -z "$logdir" ]; then + usage + exit 1 +fi + +logdirbase=$logdir +NUM_REPEATS=3 + +# Function to run experiments +run_experiments() { + local size=$1 + local logpathbase="${logdirbase}/${size}" + local start_time=$SECONDS + + # Define RENDERERS and SOLVERS array based on size + declare -a RENDERERS + declare -a SOLVERS + if [ "$size" == 
"balanced-hypotheses" ]; then + RENDERERS=("markdown" "csv" "json" "language-tabular" "language-corrset" "corrset") + SOLVERS=("generation/direct/gpt-3.5-turbo" + "generation/cot/gpt-3.5-turbo" + "generation/hhh/gpt-4-base" + "generation/cot_hhh/gpt-4-base" + "generation/direct/gpt-4-1106-preview" + "generation/cot/gpt-4-1106-preview") + elif [ "$size" == "balanced-ctrl" ]; then + RENDERERS=("csv" "language-corrset") + SOLVERS=("generation/direct/gpt-3.5-turbo" + "generation/cot/gpt-3.5-turbo" + "generation/hhh/gpt-4-base" + "generation/cot_hhh/gpt-4-base" + "generation/direct/gpt-4-1106-preview" + "generation/cot/gpt-4-1106-preview") + else + RENDERERS=("csv" "language-corrset") + SOLVERS=("generation/direct/gpt-4-1106-preview") + fi + + # Main loop + for ((i = 1; i <= NUM_REPEATS; i++)); do + for solver in "${SOLVERS[@]}"; do + for renderer in "${RENDERERS[@]}"; do + run_solver $solver $renderer $size $i "$logpathbase" + done + done + run_solver "identifying_variables/random" "corrset" $size $i "$logpathbase" + run_solver "identifying_variables/noctrl" "corrset" $size $i "$logpathbase" + done + + local end_time=$SECONDS + echo "Done running experiments for $size size, all logs in $logpathbase" + echo "Total execution time: $((end_time - start_time)) seconds." +} + +# Function to run a single solver +run_solver() { + local solver=$1 + local renderer=$2 + local size=$3 + local seed=$4 + local logpathbase=$5 + local solver_dotted=${solver//\//.} + + local record_path="${logpathbase}/${solver_dotted}_${renderer}_${size}_${seed}" + echo "Running $solver with $renderer renderer and $size data size; seed $seed" + + local sub_start_time=$(date +%s) + oaieval "$solver" "identifying_variables.${renderer}.${size}" --record_path "$record_path.log" --seed $seed + local sub_end_time=$(date +%s) + echo "${solver_dotted}_${renderer}_${size} execution time: $((sub_end_time - sub_start_time)) seconds." 
+ + skip_tree_solvers=("identifying_variables/random" "identifying_variables/noctrl") + if [[ ! "${skip_tree_solvers[@]}" =~ "$solver" ]] && [ "$size" == "balanced-hypotheses" ]; then + echo "Now repeating with show_tree=True" + oaieval "$solver" "identifying_variables.${renderer}.${size}" --extra_eval_params show_tree=True --record_path "${record_path}_tree.log" --seed $seed + fi +} + +run_experiments "${size}" diff --git a/evals/elsuite/identifying_variables/scripts/table_utils.py b/evals/elsuite/identifying_variables/scripts/table_utils.py new file mode 100644 index 0000000000..3991cd469b --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/table_utils.py @@ -0,0 +1,66 @@ +from typing import Dict, List +from pathlib import Path + +import numpy as np +import pandas as pd + + +def make_main_metric_table( + results_dict: Dict, + metric: str, + solvers: List[str], + renderers: List[str], + save_dir: Path, +): + """ + Makes and saves a table containing the information of performance of + each solver for each renderer for each variant of the eval on + a given metric. + - Table rows are solvers; they are multi-rows, so each row has two subrows: with + tree and without tree + - Table columns are renderers; they are multi-columns, so each column has two + subcolumns: mean and sem (standard error of the mean) + + Args: + results_dict: dictionary containing the results of the eval. See + `initialize_default_results_dict` and `populate_default_results_dict` in + `process_results.py`. 
+ metric: the name of the metric we want to make the table for + solvers: list of solvers we want to include in the table + renderers: list of renderers we want to include in the table + save_dir: directory to save the table in (as a CSV file) + """ + + # only keep the requested metric from results_dict + filtered_results_dict = results_dict[metric] + # flatten into tuples + data_tuples = [] + for stat, solver_data in filtered_results_dict.items(): + for solver, renderer_data in solver_data.items(): + for renderer, tree_data in renderer_data.items(): + for tree_type, value in tree_data.items(): + if value is not None: + data_tuples.append((solver, tree_type, renderer, stat, value)) + + df = pd.DataFrame( + data_tuples, columns=["Solver", "Tree", "Renderer", "Stat", "Value"] + ) + df = df.pivot_table( + index=["Solver", "Tree"], columns=["Renderer", "Stat"], values="Value" + ) + # restore the requested solver/renderer order (pivot_table sorts the index and columns) + new_index = [ + (solver, tree) for solver in solvers for tree in ["with tree", "without tree"] + ] + new_columns = pd.MultiIndex.from_product( + [renderers, df.columns.levels[1]], names=df.columns.names + ) + df = df.reindex(new_index, columns=new_columns) + + # delete the with tree rows for the treeless solvers + for solver in solvers[-2:]: + df.drop((solver, "with tree"), inplace=True) + + # save table + save_path = save_dir / f"{metric}_table.csv" + df.to_csv(save_path) diff --git a/evals/elsuite/identifying_variables/solvers.py b/evals/elsuite/identifying_variables/solvers.py new file mode 100644 index 0000000000..c6010c74da --- /dev/null +++ b/evals/elsuite/identifying_variables/solvers.py @@ -0,0 +1,48 @@ +import random + +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import TaskState + + +class RandomSolver(Solver): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _solve(self, task_state: TaskState) -> SolverResult: + valid_hyp = 
random.uniform(0, 1) < 0.5 + + variables = task_state.current_state["variables"] + n_vars_to_sample = random.randint(2, len(variables)) + ind_var, dep_var, *ctrl_vars = random.sample(variables, n_vars_to_sample) + if len(ctrl_vars) == 0: + ctrl_vars = "none" + else: + ctrl_vars = ", ".join(ctrl_vars) + + solver_string = f"[@ANSWER valid_hyp: {valid_hyp}; independent: {ind_var}; dependent: {dep_var}; control: {ctrl_vars}]" + + return SolverResult(output=solver_string) + + +class NoCtrl(Solver): + """ + Solver that always returns no control variables + (i.e. "none", interpreted as an empty list by the eval) + what it returns for the other variables is arbitrary + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _solve(self, task_state: TaskState) -> SolverResult: + # we don't care about valid_hyp and ind/dep vars for this solver + # it's only used for the ctrl variables subtask + valid_hyp = True + variables = task_state.current_state["variables"] + ind_var, dep_var = random.sample(variables, 2) + + # it just always returns no control variables + ctrl_vars = "none" + solver_string = f"[@ANSWER valid_hyp: {valid_hyp}; independent: {ind_var}; dependent: {dep_var}; control: {ctrl_vars}]" + + return SolverResult(output=solver_string) diff --git a/evals/elsuite/identifying_variables/structs.py b/evals/elsuite/identifying_variables/structs.py new file mode 100644 index 0000000000..90b47b96b0 --- /dev/null +++ b/evals/elsuite/identifying_variables/structs.py @@ -0,0 +1,49 @@ +"""Custom data structures for the eval""" +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import networkx as nx + + +@dataclass +class Answer: + valid_hypothesis: bool + ind_var: Optional[str] + dep_var: Optional[str] + ctrl_vars: Optional[List[str]] + + +@dataclass +class Sample: + """ + A sample of the dataset for the eval. + + Args: + variable_metadata (Dict) : A dictionary mapping each variable name to its metadata. 
+ Each variable's metadata is a dictionary containing: + - 'gen_method': A dictionary specifying the generation method for the + variable, including: + - 'name': Name of the latent function or distribution. + - 'input_x': Name of the input variable, if applicable. + - 'kwargs': Additional arguments for the latent function. + - 'corrs': A set of variables correlated with this variable. + hypotheses (nx.DiGraph): A directed acyclic graph (DAG) representing the hypotheses. + target_hypothesis (Tuple[str, str]) A tuple (independent_variable, dependent_variable) + representing the hypothesis of interest. + sample_metadata (Dict): A dictionary with additional metadata, including: + - 'num_obs_samples': Number of observations generated per variable. + - 'snr': Signal-to-noise ratio applied to the observations. + causal_graph (nx.DiGraph): A randomly generated DAG representing the underlying + causal relationships among variables. Represented as nx.DiGraph. + gold_label (Answer): The gold label for the sample. + num_not_ctrl (Optional[int]): The number of variables not controlled for. None + if the hypothesis is invalid. 
+ """ + + variable_metadata: Dict + hypotheses: nx.DiGraph + target_hypothesis: Tuple[str, str] + sample_metadata: Dict + causal_graph: nx.DiGraph + gold_label: Answer + num_not_ctrl: Optional[int] diff --git a/evals/elsuite/identifying_variables/utils.py b/evals/elsuite/identifying_variables/utils.py new file mode 100644 index 0000000000..6918926bdf --- /dev/null +++ b/evals/elsuite/identifying_variables/utils.py @@ -0,0 +1,91 @@ +import re +from typing import Dict + +import networkx as nx +import numpy as np + +from evals.elsuite.identifying_variables.structs import Answer, Sample +from evals.solvers.solver import SolverResult + + +def parse_solver_preds(solver_result: SolverResult) -> Answer: + solver_string = solver_result.output.strip().lower() + + pattern = ( + r"\[@answer " # Matches the beginning of the answer + r"valid_hyp: (true|false|True|False)" # valid hyp part + r"(?:; independent: ([^;]*))?" # Optionally matches the independent part + r"(?:; dependent: ([^;]*))?" # Optionally matches the dependent part + r"(?:; control: ([^\]]*))?" 
# Optionally matches the control part + r"\]" # Matches the end of the answer + ) + + match = re.search(pattern, solver_string) + + if match: + valid_hyp = match.group(1).lower() == "true" + if not valid_hyp: + return Answer( + valid_hypothesis=False, + ind_var=None, + dep_var=None, + ctrl_vars=None, + ) + ind_var = match.group(2) + ind_var = ind_var if ind_var is not None else "WRONG" + dep_var = match.group(3) + dep_var = dep_var if dep_var is not None else "WRONG" + ctrl_vars = match.group(4) + if ctrl_vars is not None: + ctrl_vars = ctrl_vars.split(",") + ctrl_vars = [var.strip() for var in ctrl_vars] + if ctrl_vars[0].lower().strip("\"'`«»<>") == "none": + ctrl_vars = [] + else: + ctrl_vars = ["WRONG"] + return Answer( + valid_hypothesis=True, + ind_var=ind_var, + dep_var=dep_var, + ctrl_vars=ctrl_vars, + ) + else: + raise ValueError("Invalid solver output") + + +def sample_serializer(obj): + """ + Custom serializer to pass to json.dumps when + saving a sample dictionary to jsonl + """ + if isinstance(obj, set): + return list(obj) + elif isinstance(obj, nx.DiGraph): + return nx.to_dict_of_lists(obj) + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + + +def json_to_sample(serialized_sample: Dict) -> Sample: + """Reads sample from jsonl into Sample dataclass""" + hypotheses = nx.from_dict_of_lists(serialized_sample["hypotheses"], create_using=nx.DiGraph) + causal_graph = nx.from_dict_of_lists(serialized_sample["causal_graph"], create_using=nx.DiGraph) + gold_label = Answer(**serialized_sample["gold_label"]) + + # convert corrs in variable_metadata from lists to sets + for var in serialized_sample["variable_metadata"]: + serialized_sample["variable_metadata"][var]["corrs"] = set( + serialized_sample["variable_metadata"][var]["corrs"] + ) + + return Sample( + variable_metadata=serialized_sample["variable_metadata"], + hypotheses=hypotheses, + target_hypothesis=serialized_sample["target_hypothesis"], + 
sample_metadata=serialized_sample["sample_metadata"], + causal_graph=causal_graph, + gold_label=gold_label, + num_not_ctrl=serialized_sample["num_not_ctrl"], + ) diff --git a/evals/elsuite/incontext_rl/README.md b/evals/elsuite/incontext_rl/README.md new file mode 100644 index 0000000000..69dbde303b --- /dev/null +++ b/evals/elsuite/incontext_rl/README.md @@ -0,0 +1,74 @@ +# In-Context RL + +This eval tests models' ability to solve RL environments simply by interacting with them in-context, without dedicated training or fine-tuning. + +## Usage + +Run with: + +```bash +oaieval incontext_rl +``` + +For examples of tested solvers, see [`./scripts/run_experiments.sh`](./scripts/run_experiments.sh). + +## Dataset + +The eval is currently set up to test models on the following canonical RL environments: +1. [FrozenLake-v1](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) (non-slippery version, default map). 4x4 gridworld where the agent has to reach the goal without falling into traps. +2. [CliffWalking-v0](https://gymnasium.farama.org/environments/toy_text/cliff_walking/). 4x12 gridworld where the agent has to reach the other side of the map without falling off a cliff. +3. [BanditTwoArmedHighLowFixed-v0](https://github.com/james-aung/gymnasium-bandits). Stochastic two-armed bandit setup where Arm 1 pays out 80% of the time with reward 1, and Arm 2 pays out 20% of the time with reward 1. +4. [BanditTenArmedRandomFixed-v0](https://github.com/james-aung/gymnasium-bandits). Stochastic ten-armed bandit setup where each arm has some randomly initialized probability of payout. + +Besides these four environments, our eval is also built to be compatible with any environment that has discrete action and observation spaces and uses the Gymnasium API. Future work may generalize our eval to work with environments with other types of action/observation spaces. 
+ +## Evaluation Process + +Each run of the eval tests the model on all four environments in the dataset, and has the model take steps in each environment until 200 steps are taken or the model’s context limit is reached. + +At each step, the eval provides the following to the model: +- The next observation and the reward from the last action. The model is also told when the environment has reset due to its action leading to a termination. +- How many of the maximum number of steps it has already taken. +- The total reward it has accumulated so far across all episodes. + +If an episode ends, the environment resets and a new episode begins. + +If the eval receives 4 responses in a row from which we cannot parse an action selection, we end the evaluation for that environment. (This provides a natural end for runs where the model’s context window is exceeded.) + + +## Prompts + +We refer readers to the [`./defaults.py`](./defaults.py) file for the `task_description_template` and other prompts used in the eval. + +## Metrics + +We provide the following metrics per evaluated environment: + +| **Metric** | **Notes** | +|---|---| +| `average_episode_reward` | The average reward achieved per episode | +| `total_steps` | The number of steps taken across all episodes before the environment sample ended | +| `invalid_response_rate` | % of responses that were in an invalid format for the eval | + + +## Token Usage Estimates + + +| Model | Token Usage Per Run | +|---|---| +| **gpt-3.5-turbo** | 4200000 ± 400000 | +| **gpt-4-turbo-preview** | 21900000 ± 10100000 | +| **mixtral-8x7b** | 2700000 ± 800000 | + + +## Future modifications + +- Extend the eval to work with other observation and action spaces beyond Discrete spaces + +## Version History + +- v0: Initial version released + +## Contribution Statement + +Eval design, implementation, and results evaluation were primarily conducted by James Aung. 
Chan Jun Shern was responsible for code reviews throughout the implementation process, along with fine-grained feedback on the project in general. Additional guidance was provided by Steven Adler, who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation. \ No newline at end of file diff --git a/evals/elsuite/incontext_rl/anti-cot_solver.py b/evals/elsuite/incontext_rl/anti-cot_solver.py new file mode 100644 index 0000000000..40b3997e3c --- /dev/null +++ b/evals/elsuite/incontext_rl/anti-cot_solver.py @@ -0,0 +1,38 @@ +from typing import Any +from evals.solvers.solver import NestedSolver, Solver, SolverResult, SolverSpec +from evals.task_state import Message, TaskState + +ANTI_COT_TEMPLATE = "RESPOND ONLY WITH YOUR FINAL ANSWER IN THE FORMAT REQUESTED. DO NOT OUTPUT ANY ADDITIONAL REASONING OR TEXT." + +class AntiCoTSolver(NestedSolver): + """ + Instructs the model to not do any further reasoning and just respond with the final answer. 
+ """ + + def __init__( + self, + solver: SolverSpec, + registry: Any = None, + ): + super().__init__(solver=solver) + + @property + def solver(self) -> Solver: + return self.get_solver("solver") + + def _solve( + self, + task_state: TaskState, + **kwargs, + ) -> SolverResult: + task_state.messages += ( + [ + Message(role="system", content=ANTI_COT_TEMPLATE), + ] + ) + solver_result = self.solver(task_state=task_state, **kwargs) + return solver_result + + @property + def name(self) -> str: + return f"Anti-CoT_{self.solver.name}" diff --git a/evals/elsuite/incontext_rl/baselines.py b/evals/elsuite/incontext_rl/baselines.py new file mode 100644 index 0000000000..34f65d6caf --- /dev/null +++ b/evals/elsuite/incontext_rl/baselines.py @@ -0,0 +1,118 @@ +import random + +import numpy as np + +from evals.elsuite.incontext_rl.eval import CurrentState +from evals.record import record_sampling +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import TaskState + + +class RandomSolver(Solver): + def __init__(self, *args, **kwargs): + pass + + def _solve( + self, + task_state: TaskState, + **kwargs, + ) -> SolverResult: + + cs: CurrentState = task_state.current_state + + try: + action = cs.action_space.sample() + response = f"[SELECT: {action}]" + except Exception as e: + response = f"Error: {e}" + + record_sampling( + prompt=cs.observations[-1], + sampled=response, + model="incontext_rl_random", + ) + + return SolverResult(response) + + +class QlearningSolver(Solver): + def __init__( + self, + learning_rate=0.7, + gamma=0.95, + epsilon=1.0, + min_epsilon=0.05, + max_epsilon=1.0, + decay_rate=0.0005, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.learning_rate = learning_rate + self.gamma = gamma + self.epsilon = epsilon + self.min_epsilon = min_epsilon + self.max_epsilon = max_epsilon + self.decay_rate = decay_rate + self.q_table = None + + def initialize_q_table(self, observation_space_size, action_space_size): + self.q_table 
= np.zeros((observation_space_size, action_space_size)) + + def select_action(self, state, action_space): + if random.uniform(0, 1) < self.epsilon: + return action_space.sample() # Explore action space + else: + return np.argmax(self.q_table[state][:]) # Exploit learned values + + def update_q_table(self, state, action, reward, next_state): + next_max = np.max(self.q_table[next_state]) + self.q_table[state, action] = self.q_table[state, action] + self.learning_rate * ( + reward + self.gamma * next_max - self.q_table[state, action] + ) + + def reduce_epsilon(self, episode_number): + self.epsilon = self.min_epsilon + (self.max_epsilon - self.min_epsilon) * np.exp( + -self.decay_rate * episode_number + ) + + def _solve(self, task_state: TaskState, **kwargs) -> SolverResult: + + cs: CurrentState = task_state.current_state + + # TODO these might not be true if environment is not discrete + assert ( + cs.observation_space_n is not None + ), "Environment must have discrete observation space" + assert cs.action_space_n is not None, "Environment must have discrete action space" + + if self.q_table is None: + print("Initializing Q-table") + self.initialize_q_table( + observation_space_size=cs.observation_space_n, action_space_size=cs.action_space_n + ) + + # This shouldn't run on the first step + if len(cs.actions) >= 1 and len(cs.rewards) >= 1 and len(cs.observations) >= 2: + print(cs.actions) + self.update_q_table( + state=cs.observations[-2], + action=cs.actions[-1], + reward=cs.rewards[-1], + next_state=cs.observations[-1], + ) + print( + f"The last action {cs.actions[-1]} resulted in reward {cs.rewards[-1]}. 
We went from state {cs.observations[-2]} to state {cs.observations[-1]}" + ) + self.reduce_epsilon(episode_number=len(cs.episode_end_steps)) + + action = self.select_action(state=cs.observations[-1], action_space=cs.action_space) + response = f"[SELECT: {action}]" + + record_sampling( + prompt=cs.observations[-1], + sampled=response, + model="incontext_rl_qlearning", + ) + + return SolverResult(response) diff --git a/evals/elsuite/incontext_rl/defaults.py b/evals/elsuite/incontext_rl/defaults.py new file mode 100644 index 0000000000..2f1cb06fc2 --- /dev/null +++ b/evals/elsuite/incontext_rl/defaults.py @@ -0,0 +1,30 @@ +from string import Template + +task_description_template = Template("""$explanations + +You have the following $action_space actions available to you: $actions. + +On each turn, you will receive an observation from the environment, which may be one of the following $observation_space possible observations: +$observations +Each action you take will give you a reward (which might be negative) and may move you to a new state. + +You should take actions such that you accumulate the highest possible reward across all your actions. + +This is how we will play the game: +1. You will consider all the information you have received to decide upon the best action you can take now. +2. You will select your action from the list above by specifying the number key of the action in the command [SELECT: x], where x is the number key of the action. +3. Your selected action will be taken. +4. As a result of your action, you will be given an observation from the environment and you may receive some reward. +5. Repeat from step 1. +""") + +step_counter = Template("Total actions taken so far: $step_count") +reward_counter = Template("Total reward so far: $reward_count") +reset_msg = Template("""After the game reset you are now in $observation. +Please pick an action, providing your reasoning. 
You must format your final action choice as [SELECT: x]""") +step_result = Template("""You took Action $action. You are now in $next_observation. +The last step you did provided reward: $reward. +Please pick an action, providing your reasoning. You must format your final action choice as [SELECT: x]""") +step_result_reset = Template("""You took Action $action. You arrived at $next_observation. +The last step made the game reset. +The last step you did provided reward: $reward.""") \ No newline at end of file diff --git a/evals/elsuite/incontext_rl/env_setup.py b/evals/elsuite/incontext_rl/env_setup.py new file mode 100644 index 0000000000..31ffcba534 --- /dev/null +++ b/evals/elsuite/incontext_rl/env_setup.py @@ -0,0 +1,12 @@ +""" +Optional setup scripts for specific environments. +""" + +def setup_GymnasiumBandits(): + import gymnasium_bandits + return + +ENV_SETUP_FUNCS = { + "BanditTwoArmedHighLowFixed-v0": setup_GymnasiumBandits, + "BanditTenArmedRandomFixed-v0": setup_GymnasiumBandits, +} \ No newline at end of file diff --git a/evals/elsuite/incontext_rl/eval.py b/evals/elsuite/incontext_rl/eval.py new file mode 100644 index 0000000000..a1fac2101e --- /dev/null +++ b/evals/elsuite/incontext_rl/eval.py @@ -0,0 +1,299 @@ +import logging +import random +import re +from dataclasses import dataclass, field +from typing import Any, List, Optional + +import gymnasium as gym +import numpy as np + +import evals +from evals.api import CompletionFn +from evals.elsuite.incontext_rl.defaults import ( + reset_msg, + reward_counter, + step_counter, + step_result, + step_result_reset, + task_description_template, +) +from evals.elsuite.incontext_rl.env_setup import ENV_SETUP_FUNCS +from evals.eval import SolverEval +from evals.solvers.solver import Solver +from evals.task_state import Message, TaskState + +logger = logging.getLogger(__name__) + + +@dataclass +class CurrentState: + action_space: gym.Space + observation_space: gym.Space + action_space_n: int + 
observation_space_n: int + invalid_responses: int = 0 + total_responses: int = 0 + actions: List = field(default_factory=list) + rewards: List[float] = field(default_factory=list) + observations: List = field(default_factory=list) + episode_end_steps: List[int] = field(default_factory=list) + + +class InContextRl(SolverEval): + def __init__( + self, + completion_fns: list[CompletionFn], + max_steps: int = 200, # maximum possible steps per sample, optional + max_invalid_responses: int = 4, # maximum invalid responses from Solver before terminating sample + max_num_messages_allowed: int = 2048, # maximum number of messages allowed by OpenAI API + use_explanations: bool = False, # Whether to include a key for how to understand action and observation spaces + *args, + **kwargs, + ): + super().__init__(completion_fns, *args, **kwargs) + self.max_steps = max_steps + self.max_invalid_responses = max_invalid_responses + self.use_explanations = use_explanations + self.max_num_messages_allowed = max_num_messages_allowed + + def eval_sample(self, solver: Solver, sample: Any, rng: random.Random): + + # Validate sample + required_keys = ["env", "env_id", "explanations"] + assert all( + key in sample for key in required_keys + ), f"Sample missing required keys: {required_keys}" + assert isinstance(sample["env"], gym.Env) + assert isinstance(sample["env_id"], str) + assert isinstance(sample["explanations"], str) + + env = sample["env"] + ts = TaskState( + task_description=self._generate_task_description(env, sample), + messages=[], + current_state=CurrentState( + action_space=env.action_space, + observation_space=env.observation_space, + action_space_n=env.action_space.n, # TODO might not be available for all envs, check when adding a continuous env + observation_space_n=env.observation_space.n, # TODO might not be available for all envs, check when adding a continuous env + ), + ) + + # Reset environment and update task state + observation, _ = env.reset(seed=42) + 
ts.current_state.observations.append(observation) + + # Tell model starting observation and ask it to pick an action + self._add_reset_message_to_task_state(ts, observation, sample) + + for _ in range(self.max_steps): + self._add_recap_message_to_task_state( + ts, ts.current_state.actions, ts.current_state.rewards + ) + + action = self._try_get_valid_action(solver, ts, env.action_space.n) + + if action is None: + logger.info("Ending sample since couldn't parse an action.") + break + else: + next_observation, reward, terminated, truncated, _ = env.step(action) + ts.current_state.actions.append(action) + ts.current_state.rewards.append(float(reward)) + ts.current_state.observations.append(next_observation) + + if terminated or truncated: + # Tell model that episode ended and what reward was received + content = self._format_step_message( + action, next_observation, reward, sample, terminated=True + ) + ts.messages += [Message(role="user", content=content)] + + # Log what step the episode ended on + ts.current_state.episode_end_steps.append(len(ts.current_state.actions)) + + # Reset environment + observation, _ = env.reset(seed=42) + ts.current_state.observations.append(observation) + + # Tell model new observation after reset and ask it to pick an action + self._add_reset_message_to_task_state(ts, observation, sample) + else: + content = self._format_step_message(action, next_observation, reward, sample) + ts.messages += [Message(role="user", content=content)] + + env.close() + + episode_rewards = self._calculate_episode_rewards( + ts.current_state.episode_end_steps, ts.current_state.rewards + ) + evals.record.record_metrics( + environment=f"{env.spec.id} {env.spec.kwargs}", + explanations=self.use_explanations, + total_return=sum(ts.current_state.rewards), + total_steps=len(ts.current_state.actions), + actions=ts.current_state.actions, + rewards=ts.current_state.rewards, + episode_rewards=episode_rewards, + average_episode_reward=float(np.mean(episode_rewards)), + 
average_reward_last_5_episodes=float(np.mean(episode_rewards[-5:])), + average_reward_last_10_episodes=float(np.mean(episode_rewards[-10:])), + average_reward_last_20_episodes=float(np.mean(episode_rewards[-20:])), + average_reward_last_50_episodes=float(np.mean(episode_rewards[-50:])), + invalid_response_rate=ts.current_state.invalid_responses + / ts.current_state.total_responses + if ts.current_state.total_responses > 0 + else 0, + episode_end_steps=ts.current_state.episode_end_steps, + ) + + def run(self, recorder: evals.record.Recorder): + samples = self.get_samples() + for sample in samples: + # Create environments and pass them to each thread via the sample + # (gym envs don't like being created in the thread itself) + sample["env"] = self._make_env(sample) + self.eval_all_samples(recorder, samples) + + metrics = recorder.get_metrics() + + results = [] + + for metric in metrics: + env_result = { + "env": metric["environment"], + "metrics": { + "explanations": metric["explanations"], + "average_episode_reward": metric["average_episode_reward"], + "average_reward_last_5_episodes": metric["average_reward_last_5_episodes"], + "average_reward_last_10_episodes": metric["average_reward_last_10_episodes"], + "average_reward_last_20_episodes": metric["average_reward_last_20_episodes"], + "average_reward_last_50_episodes": metric["average_reward_last_50_episodes"], + "episode_rewards": metric["episode_rewards"], + "total_return": metric["total_return"], + "total_steps": metric["total_steps"], + "actions": metric["actions"], + "rewards": metric["rewards"], + "invalid_response_rate": metric["invalid_response_rate"], + "episode_end_steps": metric["episode_end_steps"], + }, + } + results.append(env_result) + + final_result = {"environments": results} + return final_result + + def _make_env(self, sample: dict) -> gym.Env: + env_id = sample["env_id"] + env_args = sample.get("env_args", {}) + if env_id in ENV_SETUP_FUNCS: + # Optional setup scripts for specific environments + 
ENV_SETUP_FUNCS[env_id]() + return gym.make(env_id, **env_args) + + def _generate_task_description(self, env: gym.Env, sample: dict) -> str: + + actions = [str(action) for action in range(env.action_space.n)] + observations = [ + f"Observation {observation}" for observation in range(env.observation_space.n) + ] + explanations = ( + sample["explanations"] if self.use_explanations else "You are playing a game." + ) + + return task_description_template.substitute( + action_space=env.action_space.n, + actions=actions, + observation_space=env.observation_space.n, + observations=observations, + explanations=explanations, + ) + + def _try_get_valid_action( + self, solver: Solver, task_state: TaskState, action_space: int + ) -> Optional[int]: + number_of_attempts = 0 + while number_of_attempts < self.max_invalid_responses: + if len(task_state.messages) > self.max_num_messages_allowed: + logger.info( + f"Exceeded maximum number of messages allowed ({self.max_num_messages_allowed})." + ) + return None + solver_response = solver(task_state).output + action = self._parse_action(solver_response) + task_state.messages += [Message(role="assistant", content=solver_response)] + task_state.current_state.total_responses += 1 + # Check if action is valid + if action not in range( + action_space + ): # TODO this might not work for non-discrete action spaces, check with more complex env + task_state.messages += [ + Message( + role="user", + content="Invalid action. Please provide ONE valid action by outputting your selection in the format [SELECT: x]. 
Only output this selection ONCE.", + ) + ] + task_state.current_state.invalid_responses += 1 + number_of_attempts += 1 + else: + return action + # If the loop exits due to reaching max invalid attempts, log and return None + logger.info(f"Exceeded maximum invalid action attempts ({self.max_invalid_responses}).") + return None + + def _parse_action(self, raw_response: str) -> Optional[int]: + pattern = r"\[SELECT: (\d+)\]" + matches = re.findall(pattern, raw_response) + + actions = [int(match) for match in matches] + if not actions: + logger.info(f"No action selections found in response: {raw_response}") + return None + if len(actions) > 1: + logger.info(f"Multiple action selections found in response: {raw_response}") + return None + return actions[0] + + def _add_message_to_task_state(self, task_state: TaskState, role: str, content: str) -> None: + """ + Adds a message to the task state, combining it with the previous message if they are from the same role. + """ + if task_state.messages and task_state.messages[-1].role == role: + task_state.messages[-1].content += "\n\n" + content + else: + task_state.messages.append(Message(role=role, content=content)) + + def _add_reset_message_to_task_state( + self, task_state: TaskState, observation: int, sample: dict + ) -> None: + content = reset_msg.substitute(observation=f"Observation {observation}") + self._add_message_to_task_state(task_state, "user", content) + + def _add_recap_message_to_task_state( + self, task_state: TaskState, actions: List, rewards: List[float] + ) -> None: + step_count = step_counter.substitute(step_count=len(actions)) + reward_count = reward_counter.substitute(reward_count=sum(rewards)) + content = "\n".join([step_count, reward_count]) + self._add_message_to_task_state(task_state, "user", content) + + def _format_step_message( + self, action: int, observation: int, reward: float, sample: dict, terminated: bool = False + ) -> str: + observation_desc = f"Observation {observation}" + if terminated: 
+ template = step_result_reset + else: + template = step_result + return template.substitute(action=action, next_observation=observation_desc, reward=reward) + + def _calculate_episode_rewards(self, episode_end_steps, rewards): + episode_rewards = [] + if not episode_end_steps: # Handle case where there was only 1 episode + return [sum(rewards)] + start_index = 0 + for end_index in episode_end_steps: + episode_reward = sum(rewards[start_index:end_index]) + episode_rewards.append(episode_reward) + start_index = end_index + return episode_rewards diff --git a/evals/elsuite/incontext_rl/requirements.txt b/evals/elsuite/incontext_rl/requirements.txt new file mode 100644 index 0000000000..2712d1140b --- /dev/null +++ b/evals/elsuite/incontext_rl/requirements.txt @@ -0,0 +1,3 @@ +# Additional requirements for specific environments +gymnasium +git+https://github.com/james-aung/gymnasium-bandits \ No newline at end of file diff --git a/evals/elsuite/incontext_rl/scripts/plot_experiments.py b/evals/elsuite/incontext_rl/scripts/plot_experiments.py new file mode 100644 index 0000000000..9e8e27f82b --- /dev/null +++ b/evals/elsuite/incontext_rl/scripts/plot_experiments.py @@ -0,0 +1,363 @@ +import json +import numpy as np +import matplotlib.pyplot as plt +from scipy.stats import sem +import pandas as pd +from pathlib import Path +import matplotlib.colors as mcolors +import argparse +import seaborn as sns + +from evals.utils.log_utils import extract_spec, get_final_results_from_dir + +WINDOW_SIZES = { + "FrozenLake-v1 {'map_name': '4x4', 'is_slippery': False}": 20, + "BanditTwoArmedHighLowFixed-v0 {}": 40, + "BanditTenArmedRandomFixed-v0 {}": 40, + "CliffWalking-v0 {}": 20, + "FrozenLake-v1 {'map_name': '4x4', 'is_slippery': False, 'desc': ['SHFF', 'FFFF', 'FFGH', 'HFHF']}": 20, + "default": 20, +} + +PRETTY_MODEL_NAMES = { + 'generation/direct/gpt-4-turbo-preview': 'GPT-4 Turbo Preview', + 'incontext_rl/random': 'Random Strategy', + 'generation/direct/gpt-3.5-turbo': 'GPT-3.5 
Turbo', + 'incontext_rl/qlearning_scratch': 'Q-Learning from scratch', + 'incontext_rl/qlearning_trained': 'Q-Learning trained', + 'generation/direct/gemini-pro': 'Gemini Pro 1.0', + 'generation/direct/mixtral-8x7b-instruct': 'Mixtral 8x7b', +} + +PRETTY_ENV_TITLES = { + "FrozenLake-v1 {'map_name': '4x4', 'is_slippery': False}": 'Frozen Lake (4x4, Non-slippery)', + "BanditTwoArmedHighLowFixed-v0 {}": "Two-Armed Bandit", + "BanditTenArmedRandomFixed-v0 {}": "Ten-Armed Bandit", + "CliffWalking-v0 {}": "Cliff Walking", + "FrozenLake-v1 {'map_name': '4x4', 'is_slippery': False, 'desc': ['SFFF', 'FHFH', 'FFFH', 'GFFH']}": 'Frozen Lake Custom Map (4x4, Non-slippery)', +} + +MODEL_STYLES = { + 'generation/direct/gpt-4-turbo-preview': {'line_style': '-', 'color': 'purple', 'alpha': 0.7}, + 'incontext_rl/random': {'line_style': ':', 'color': 'grey', 'alpha': 0.7}, + 'generation/direct/gpt-3.5-turbo': {'line_style': '-', 'color': 'green', 'alpha': 0.7}, + 'incontext_rl/qlearning_scratch': {'line_style': '--', 'color': 'grey', 'alpha': 0.7}, + 'incontext_rl/qlearning_trained': {'line_style': '-', 'color': 'black', 'alpha': 0.7}, + 'generation/direct/gemini-pro': {'line_style': '-', 'color': 'blue', 'alpha': 0.7}, + 'generation/direct/mixtral-8x7b-instruct': {'line_style': '-', 'color': 'orange', 'alpha': 0.7}, + 'default': {'line_style': '-', 'color': 'black', 'alpha': 0.5}, +} + +def calculate_episode_rewards(row: pd.Series) -> list: + """ + Calculate the rewards for each episode based on the episode end steps and rewards. 
+ """ + episode_end_steps = row['episode_end_steps'] + rewards = row['rewards'] + episode_rewards = [] + if not episode_end_steps: # Handle case where there was only 1 episode + return [sum(rewards)] + start_index = 0 + for end_index in episode_end_steps: + episode_reward = sum(rewards[start_index:end_index]) + episode_rewards.append(episode_reward) + start_index = end_index + return episode_rewards + +def calculate_rolling_average(episode_rewards: list, window_size: int) -> list: + """ + Calculate the rolling average of the episode rewards using a specified window size. + """ + window_size = int(window_size) + rolling_averages = [] + for i in range(len(episode_rewards)): + # Calculate the start index for the window; ensure it's not negative + start_index = max(0, i - window_size + 1) + # Calculate the running average for the current window + window_average = np.mean(episode_rewards[start_index:i+1]) + rolling_averages.append(window_average) + return rolling_averages + +def calculate_custom_episode_end_steps_for_cliffwalking(rewards: list, existing_end_steps: list) -> list: + """ + Calculate episode end steps based on rewards and append to existing end steps. + An episode also ends when the reward is -100 i.e. when the agent falls off the cliff. + + Args: + rewards (list): List of rewards for each step in an episode. + existing_end_steps (list): List of already identified episode end steps. + + Returns: + list: Updated list of indices representing the end of each episode. + """ + new_end_steps = [i + 1 for i, reward in enumerate(rewards) if reward == -100] + # Combine existing and new end steps, remove duplicates, and sort + combined_end_steps = sorted(set(existing_end_steps + new_end_steps)) + return combined_end_steps + +def extract_results(datadir: Path) -> pd.DataFrame: + """ + Extracts results from the specified directory and returns a DataFrame. + + Args: + datadir (Path): Path to the directory containing the experiment results. 
+ + Returns: + pd.DataFrame: DataFrame containing the experiment results. + """ + print(f"Extracting results from directory: {datadir}") + df_rows = [] + final_results = get_final_results_from_dir(datadir) + if not final_results: + print("No results found in directory.") + raise ValueError("No results found in directory.") + + for path, results in final_results.items(): + print(f"Processing file: {path}") + spec = extract_spec(path) + if not spec: + raise ValueError(f"No spec found for {path}") + model = spec.get("completion_fns", [None])[0] + base_eval = spec.get("base_eval") + if not model or base_eval is None: + raise ValueError(f"Missing model or base_eval in spec for {path}") + + environments = results.get('environments', []) + for env in environments: + metrics = env.get('metrics', {}) + flattened_metrics = {f"{k}": v for k, v in metrics.items()} # Flatten metrics into separate columns + print(f"Extracted {env['env']} metrics for model: {model}") + + # Calculate custom episode end steps for CliffWalking environment + if env['env'] == "CliffWalking-v0 {}": + rewards = metrics.get('rewards', []) + existing_end_steps = metrics.get('episode_end_steps', []) + episode_end_steps = calculate_custom_episode_end_steps_for_cliffwalking(rewards, existing_end_steps) + flattened_metrics['episode_end_steps'] = episode_end_steps + + df_rows.append({"model": model, "base_eval": base_eval, "environment": env['env'], **flattened_metrics}) + + df = pd.DataFrame(df_rows) + + if 'episode_rewards' not in df.columns: + df['episode_rewards'] = df.apply(calculate_episode_rewards, axis=1) + + # For plots + df['cumulative_episode_rewards'] = df['episode_rewards'].apply(np.cumsum) + df['average_episode_reward'] = df['episode_rewards'].apply(np.mean) + df['window_size'] = df['environment'].map(WINDOW_SIZES).fillna(WINDOW_SIZES.get('default', 20)) + df['rolling_average_rewards'] = df.apply(lambda row: calculate_rolling_average(row['episode_rewards'], row['window_size']), axis=1) + + # We 
also calculate the rolling average across different window sizes + df['rolling_average_rewards_5_episodes'] = df.apply(lambda row: calculate_rolling_average(row['episode_rewards'], 5), axis=1) + df['rolling_average_rewards_10_episodes'] = df.apply(lambda row: calculate_rolling_average(row['episode_rewards'], 10), axis=1) + df['rolling_average_rewards_20_episodes'] = df.apply(lambda row: calculate_rolling_average(row['episode_rewards'], 20), axis=1) + df['rolling_average_rewards_50_episodes'] = df.apply(lambda row: calculate_rolling_average(row['episode_rewards'], 50), axis=1) + + # We also calculate the average reward for the last 5, 10, 20, and 50 episodes. For older runs, we may not have this information. + if 'average_reward_last_5_episodes' not in df.columns: + df['average_reward_last_5_episodes'] = df['episode_rewards'].apply(lambda rewards: np.mean(rewards[-5:])) + if 'average_reward_last_10_episodes' not in df.columns: + df['average_reward_last_10_episodes'] = df['episode_rewards'].apply(lambda rewards: np.mean(rewards[-10:])) + if 'average_reward_last_20_episodes' not in df.columns: + df['average_reward_last_20_episodes'] = df['episode_rewards'].apply(lambda rewards: np.mean(rewards[-20:])) + if 'average_reward_last_50_episodes' not in df.columns: + df['average_reward_last_50_episodes'] = df['episode_rewards'].apply(lambda rewards: np.mean(rewards[-50:])) + + print(f"Extraction complete. {len(df_rows)} rows in DataFrame.") + return df + +def plot_rewards(df, environment, reward_type, out_dir, window_size=None): + """ + Generalized function to plot episode, cumulative, or running average rewards for different models + on the same graph for a specific environment. It automatically determines the plot type (line or scatter) + based on the number of episodes and includes the 95% confidence intervals for line plots. + + Args: + df (pd.DataFrame): DataFrame containing the experiment results. + environment (str): Name of the environment to plot. 
+ reward_type (str): Type of reward to plot. Must be one of 'episode_rewards', 'cumulative_episode_rewards', or 'rolling_average_rewards'. + out_dir (Path): Path to the directory to save the plots. + window_size (int): Window size for calculating rolling averages. If None, the window size will be determined based on the environment. + """ + valid_reward_types = ['episode_rewards', 'cumulative_episode_rewards', 'rolling_average_rewards'] + if reward_type not in valid_reward_types: + raise ValueError(f"Invalid reward_type. Expected one of {valid_reward_types}, got {reward_type}") + + # Filter the DataFrame for the specific environment + filtered_df = df[df['environment'] == environment] + + # Explode the specified reward list into separate rows and prepare for plotting + rewards_df = filtered_df.explode(reward_type).reset_index() # Each row will be a single episode + rewards_df['episode'] = rewards_df.groupby(['model', 'index']).cumcount() + 1 # Add episode number as a column + rewards_df['reward'] = rewards_df[reward_type] # Rename the column for clarity + + truncate_per_model = True + if environment == "CliffWalking-v0 {}": + truncate_per_model = False # Hacky workaround to make better plots since some models only have 1 episode on CliffWalking + + if truncate_per_model: + filtered_rewards_df = pd.DataFrame() + for model, group in rewards_df.groupby('model'): + # Count the number of runs for each episode number + episode_counts = group.groupby('episode').size() + # Check if there are at least 3 runs for any episode number + if episode_counts.max() >= 3: + # Find the maximum episode number where at least 3 runs are available + max_episode_with_at_least_3_runs = episode_counts[episode_counts >= 3].index.max() + # Filter the group DataFrame to only include data up to this episode number + model_filtered = group[group['episode'] <= max_episode_with_at_least_3_runs] + else: + # If there are fewer than 3 runs for all episodes, include all data for this model + 
model_filtered = group + # Append the filtered data for the current model to the overall filtered DataFrame + filtered_rewards_df = pd.concat([filtered_rewards_df, model_filtered], ignore_index=True) + rewards_df = filtered_rewards_df + + plt.figure(figsize=(10, 5)) + ax = plt.gca() + + # Determine the plot type based on the number of episodes + num_episodes = len(rewards_df['episode'].unique()) + if num_episodes > 1: + # Iterate over each unique model in the DataFrame + for model in rewards_df['model'].unique(): + # Filter the DataFrame for the current model + model_df = rewards_df[rewards_df['model'] == model] + # Get the custom style for the current model using the helper function + custom_style = MODEL_STYLES.get(model, MODEL_STYLES['default']) + pretty_model_name = PRETTY_MODEL_NAMES.get(model, model) + # Plot the data for the current model on the same axes with custom settings + lineplot = sns.lineplot(data=model_df, x='episode', y='reward', estimator='mean', errorbar=('ci', 95), + linestyle=custom_style['line_style'], color=custom_style['color'], + alpha=custom_style['alpha'], label=pretty_model_name, ax=ax, + err_kws={'alpha': 0.035}) + # Add labels to the final value on the x axis + for line in lineplot.get_lines(): + x, y = line.get_data() + if len(x) > 0: # Check if there is data to plot + ax.text(x[-1], y[-1], f"{y[-1]:.2f}", color=line.get_color(), fontsize=9) + else: + # For a single episode, use scatter plot, differentiating models by color + scatterplot = sns.scatterplot(data=rewards_df, x='episode', y='reward', hue='model', ax=ax) + # Add labels to the final value on the x axis + for line in scatterplot.collections: + offsets = line.get_offsets() + if offsets.size > 0: # Check if there are points to plot + last_point = offsets[-1] + ax.text(last_point[0], last_point[1], f"{last_point[1]:.2f}", fontsize=9) + + pretty_env_title = PRETTY_ENV_TITLES.get(environment, environment) + plt.title(f'{reward_type.replace("_", " ").title()} in 
{pretty_env_title} (Window Size: {window_size})' if reward_type == 'rolling_average_rewards' else f'{reward_type.replace("_", " ").title()} in {pretty_env_title}') + plt.xlabel('Episode') + plt.ylabel('Reward') + plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left') + plt.xlim(1, num_episodes) + plt.tight_layout() + plot_dir = out_dir / reward_type + plot_dir.mkdir(parents=True, exist_ok=True) + plt.savefig(plot_dir / f'{environment}.png') + plt.show() + +def calculate_rolling_averages(df: pd.DataFrame, max_items: int = 200): + """ + Calculate the averaged final and max rolling averages for the first N items in each model and environment. + Args: + df (pd.DataFrame): DataFrame containing the experiment results. + max_items (int): Maximum number of items to consider for calculating rolling averages. + Returns: + dict: Dictionary containing the averaged final and max rolling averages for each model and environment. + """ + + model_env_averages_info = {} + for model in df['model'].unique(): + model_df = df[df['model'] == model] + model_env_averages_info[model] = {} + all_final_rolling_averages = [] # To store all final rolling averages across environments for each model + for env in model_df['environment'].unique(): + env_df = model_df[model_df['environment'] == env] + # Determine the last shared episode across all runs for the current model and environment, limited to the first max_items items + max_shared_episode = min(max_items, env_df['rolling_average_rewards'].apply(lambda rewards: len(rewards[:max_items])).min()) + # Truncate each run's rolling_average_rewards to the max shared episode and then calculate averages + truncated_averages = env_df['rolling_average_rewards'].apply(lambda rewards: rewards[:max_shared_episode]) + final_rolling_averages = round(truncated_averages.apply(lambda rewards: rewards[-1] if len(rewards) > 0 else None).mean(), 2) + max_rolling_averages = round(truncated_averages.apply(lambda rewards: max(rewards) if len(rewards) > 
0 else None).mean(), 2) + + all_final_rolling_averages.append(final_rolling_averages) # Append the final rolling average for the current environment + + model_env_averages_info[model][env] = { + 'average_final_rolling_averages': final_rolling_averages, + 'average_max_rolling_averages': max_rolling_averages, + } + + # Calculate the average final rolling average across all environments for the current model + average_final_across_envs = round(sum(all_final_rolling_averages) / len(all_final_rolling_averages), 2) if all_final_rolling_averages else None + model_env_averages_info[model]['average_final_rolling_averages_across_envs'] = average_final_across_envs + return model_env_averages_info + +def json_of_results(df: pd.DataFrame, out_dir: Path): + """ + JSON dump of the results. + + Each model will have the following information, grouping by environment: + - Average episode reward + - Last rolling average reward for each of 5, 10, 20, and 50 episodes + - Max rolling average reward across the 5, 10, 20, and 50 episodes + - Invalid response rate + + Where there are multiple runs for a model and environment, the average of the above values will be calculated. 
+ """ + + model_info = {} + for model in df['model'].unique(): + model_df = df[df['model'] == model] + model_info[model] = {} + for env in model_df['environment'].unique(): + env_df = model_df[model_df['environment'] == env] + # Calculate the average rolling averages across all runs for each window size, then find the max + average_rolling_averages_5 = env_df['rolling_average_rewards_5_episodes'].apply(pd.Series).mean().max() + average_rolling_averages_10 = env_df['rolling_average_rewards_10_episodes'].apply(pd.Series).mean().max() + average_rolling_averages_20 = env_df['rolling_average_rewards_20_episodes'].apply(pd.Series).mean().max() + average_rolling_averages_50 = env_df['rolling_average_rewards_50_episodes'].apply(pd.Series).mean().max() + + model_info[model][env] = { + 'average_episode_reward': round(env_df['average_episode_reward'].mean(), 2), + 'average_reward_last_5_episodes': round(env_df['average_reward_last_5_episodes'].mean(), 2), + 'average_reward_last_10_episodes': round(env_df['average_reward_last_10_episodes'].mean(), 2), + 'average_reward_last_20_episodes': round(env_df['average_reward_last_20_episodes'].mean(), 2), + 'average_reward_last_50_episodes': round(env_df['average_reward_last_50_episodes'].mean(), 2), + 'max_rolling_average_rewards_5_episodes': round(average_rolling_averages_5, 2), + 'max_rolling_average_rewards_10_episodes': round(average_rolling_averages_10, 2), + 'max_rolling_average_rewards_20_episodes': round(average_rolling_averages_20, 2), + 'max_rolling_average_rewards_50_episodes': round(average_rolling_averages_50, 2), + 'invalid_response_rate': round(env_df['invalid_response_rate'].mean(), 2), + } + with open(out_dir / 'model_info.json', 'w') as f: + json.dump(model_info, f, indent=4) + +def main(log_dir: str = None, out_dir: str = None): + + parser = argparse.ArgumentParser() + parser.add_argument("--log_dir", "-d", type=str, required=not log_dir) + parser.add_argument("--out_dir", "-o", type=str, required=not out_dir) + 
args = parser.parse_args() + log_dir = Path(log_dir) if log_dir else Path(args.log_dir) + out_dir = Path(out_dir) if out_dir else Path(args.out_dir) + + # Extract results from directory + df = extract_results(log_dir) + + # Plot episode, cumulative, and rolling-average rewards (with 95% confidence intervals) for each environment + for env in df['environment'].unique(): + plot_rewards(df, env, 'episode_rewards', out_dir) + plot_rewards(df, env, 'cumulative_episode_rewards', out_dir) + window_size = df[df['environment'] == env]['window_size'].iloc[0] + plot_rewards(df, env, 'rolling_average_rewards', out_dir, window_size) + + # JSON dump of the results + json_of_results(df, out_dir) + + +if __name__ == "__main__": + main() + diff --git a/evals/elsuite/incontext_rl/scripts/qlearning_baseline.ipynb b/evals/elsuite/incontext_rl/scripts/qlearning_baseline.ipynb new file mode 100644 index 0000000000..bb2dc7ae44 --- /dev/null +++ b/evals/elsuite/incontext_rl/scripts/qlearning_baseline.ipynb @@ -0,0 +1,402 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install gymnasium\n", + "!pip install numpy\n", + "!pip install git+https://github.com/james-aung/gymnasium-bandits\n", + "!pip install tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import gymnasium as gym\n", + "import random\n", + "import json\n", + "\n", + "import gymnasium_bandits" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# Training parameters\n", + "n_training_episodes = 10000 # Total training episodes\n", + "n_training_steps = 200 # Total training steps\n", + "learning_rate = 0.7 # Learning rate\n", + "\n", + "# Evaluation parameters\n", + "reward_window_size = 25 # Number of steps to consider when calculating average reward\n", + "\n", + "# Environment parameters\n", + "gamma = 0.95 # Discounting rate\n", + "\n", + "# 
Exploration parameters\n", + "max_epsilon = 1.0 # Exploration probability at start\n", + "min_epsilon = 0.05 # Minimum exploration probability\n", + "decay_rate = 0.0005 # Exponential decay rate for exploration prob" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "def initialize_q_table(state_space, action_space):\n", + " Qtable = np.zeros((state_space, action_space))\n", + " return Qtable" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "def greedy_policy(Qtable, state):\n", + " # Exploitation: take the action with the highest state-action value\n", + " action = np.argmax(Qtable[state][:])\n", + "\n", + " return action" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "def epsilon_greedy_policy(Qtable, state, epsilon):\n", + " # Randomly generate a number between 0 and 1\n", + " random_num = random.uniform(0,1)\n", + " # if random_num is greater than epsilon --> exploitation\n", + " if random_num > epsilon:\n", + " # Take the action with the highest value given a state\n", + " # np.argmax can be useful here\n", + " action = greedy_policy(Qtable, state)\n", + " # else --> exploration\n", + " else:\n", + " action = env.action_space.sample()\n", + "\n", + " return action" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "def train(n_training_steps, min_epsilon, max_epsilon, decay_rate, env, Qtable, reward_window_size=25):\n", + "\n", + " actions, rewards = [], []\n", + " total_steps = 0\n", + " episode_end_steps = []\n", + "\n", + " for _ in range(n_training_steps):\n", + " if total_steps >= n_training_steps:\n", + " break\n", + " # Reduce epsilon (because we need less and less exploration)\n", + " epsilon = min_epsilon + (max_epsilon - min_epsilon)*np.exp(-decay_rate*len(episode_end_steps))\n", + " # Reset the 
environment\n", + " state, info = env.reset()\n", + " terminated = False\n", + " truncated = False\n", + "\n", + " while not terminated and not truncated and total_steps < n_training_steps:\n", + " # Choose the action At using epsilon greedy policy\n", + " action = epsilon_greedy_policy(Qtable, state, epsilon)\n", + "\n", + " # Take action At and observe Rt+1 and St+1\n", + " new_state, reward, terminated, truncated, info = env.step(action)\n", + "\n", + " actions.append(int(action))\n", + " rewards.append(float(reward))\n", + " total_steps += 1\n", + "\n", + " # Update Q(s,a):= Q(s,a) + lr [R(s,a) + gamma * max Q(s',a') - Q(s,a)]\n", + " Qtable[state][action] = Qtable[state][action] + learning_rate * (reward + gamma * np.max(Qtable[new_state]) - Qtable[state][action])\n", + "\n", + " # Our next state is the new state\n", + " state = new_state\n", + "\n", + " if terminated or truncated:\n", + " episode_end_steps.append(total_steps)\n", + "\n", + " training_summary = {\n", + " \"reward_window_size\": reward_window_size,\n", + " \"average_reward_at_end\": sum(rewards[-reward_window_size:])/reward_window_size,\n", + " \"total_reward\": sum(rewards),\n", + " \"total_steps\": len(actions),\n", + " \"actions\": list(actions),\n", + " \"rewards\": list(rewards),\n", + " \"episode_end_steps\": episode_end_steps,\n", + " }\n", + " \n", + " print(f\"Training completed for {env.spec.id} at {total_steps} steps.\")\n", + " \n", + " return training_summary" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(env, Qtable, n_evaluation_steps=200):\n", + "\n", + " actions, rewards = [], []\n", + " total_steps = 0\n", + " episode_end_steps = []\n", + "\n", + " while total_steps < n_evaluation_steps:\n", + " # Reset the environment at the start of each new episode\n", + " state, info = env.reset()\n", + " episode_end_steps.append(total_steps)\n", + " terminated = False\n", + " truncated = False\n", + "\n", + " 
while not terminated and not truncated and total_steps < n_evaluation_steps:\n", + " # Choose the action At using greedy policy\n", + " action = greedy_policy(Qtable, state)\n", + "\n", + " # Take action At and observe Rt+1 and St+1\n", + " new_state, reward, terminated, truncated, info = env.step(action)\n", + "\n", + " actions.append(int(action))\n", + " rewards.append(float(reward))\n", + " total_steps += 1\n", + "\n", + " # Our next state is the new state\n", + " state = new_state\n", + "\n", + "\n", + " evaluation_summary = {\n", + " \"average_reward_at_end\": sum(rewards[-25:])/min(25, len(rewards)),\n", + " \"total_reward\": sum(rewards),\n", + " \"total_steps\": len(actions),\n", + " \"actions\": list(actions),\n", + " \"rewards\": list(rewards),\n", + " \"episode_end_steps\": episode_end_steps,\n", + " }\n", + " \n", + " print(f\"Evaluation completed for {env.spec.id} at {total_steps} steps.\")\n", + " \n", + " return evaluation_summary" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "frozenlake = gym.make(\"FrozenLake-v1\", is_slippery=False)\n", + "frozenlakecustom = gym.make(\"FrozenLake-v1\", is_slippery=False, desc =['SFFF', 'FHFH', 'FFFH', 'GFFH'])\n", + "twobandits = gym.make(\"BanditTwoArmedHighLowFixed-v0\")\n", + "tenbandits = gym.make(\"BanditTenArmedRandomFixed-v0\")\n", + "cliffwalking = gym.make(\"CliffWalking-v0\")\n", + "\n", + "envs = [frozenlake, frozenlakecustom, twobandits, tenbandits, cliffwalking]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training FrozenLake-v1 with args {'map_name': '4x4', 'is_slippery': False}...\n", + "Training completed for FrozenLake-v1 at 200 steps.\n", + "Training FrozenLake-v1 with args {'map_name': '4x4', 'is_slippery': False, 'desc': ['SFFF', 'FHFH', 'FFFH', 'GFFH']}...\n", + "Training completed for FrozenLake-v1 at 200 
steps.\n", + "Training BanditTwoArmedHighLowFixed-v0 with args {}...\n", + "Training completed for BanditTwoArmedHighLowFixed-v0 at 200 steps.\n", + "Training BanditTenArmedRandomFixed-v0 with args {}...\n", + "Training completed for BanditTenArmedRandomFixed-v0 at 200 steps.\n", + "Training CliffWalking-v0 with args {}...\n", + "Training completed for CliffWalking-v0 at 200 steps.\n" + ] + } + ], + "source": [ + "from datetime import datetime\n", + "\n", + "environment_results = []\n", + "\n", + "for env in envs:\n", + " print(f\"Training {env.spec.id} with args {env.spec.kwargs}...\")\n", + " Qtable = initialize_q_table(env.observation_space.n, env.action_space.n)\n", + " results = train(n_training_steps, min_epsilon, max_epsilon, decay_rate, env, Qtable, reward_window_size)\n", + " # train(n_training_steps, min_epsilon, max_epsilon, decay_rate, env, Qtable, reward_window_size)\n", + " # results = evaluate(env, Qtable)\n", + " env_result = {\"env\": f\"{env.spec.id} {env.spec.kwargs}\", \"metrics\": results}\n", + " environment_results.append(env_result)\n", + "\n", + "current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + "spec = {\"spec\": {\"completion_fns\": [\"incontext_rl/qlearning\"], \"eval_name\": \"incontext_rl.v0\", \"base_eval\": \"incontext_rl\", \"split\": \"v0\", \"created_at\": current_time}}\n", + "final_report = {\"final_report\": {\"environments\": environment_results}}\n", + "\n", + "with open('./logs/qlearning_incontext_rl.log', 'w') as f:\n", + " json.dump(spec, f)\n", + " f.write(\"\\n\")\n", + " json.dump(final_report, f)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Plotting running average reward for FrozenLake-v1\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Plotting running average reward for FrozenLake-v1\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Plotting running average reward for BanditTwoArmedHighLowFixed-v0\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Plotting running average reward for BanditTenArmedRandomFixed-v0\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Plotting running average reward for CliffWalking-v0\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "def plot_running_average(rewards, window_size=10000):\n", + " running_avg = np.convolve(rewards, np.ones(window_size)/window_size, mode='valid')\n", + " plt.plot(running_avg)\n", + " plt.title('Running Average Reward Over Time')\n", + " plt.xlabel('Step Number')\n", + " plt.ylabel('Running Average Reward')\n", + " plt.show()\n", + "\n", + "# Assuming `rewards` is a list of rewards from each episode\n", + "for env in environment_results:\n", + " rewards = env[\"metrics\"][\"rewards\"]\n", + " print(f\"Plotting running average reward for {env['env_id']}\")\n", + " plot_running_average(rewards)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/evals/elsuite/incontext_rl/scripts/run_experiments.sh b/evals/elsuite/incontext_rl/scripts/run_experiments.sh new file mode 100755 index 0000000000..9d8765dc22 --- /dev/null +++ b/evals/elsuite/incontext_rl/scripts/run_experiments.sh @@ -0,0 +1,39 @@ +#!/bin/bash +logdir=./logs +outputdir=./outputs + +timestamp=$(date +%Y%m%d_%H%M%S) +logpathbase=$logdir/$timestamp/ + +mkdir -p ${logpathbase} + +echo Running experiments and logging to $logpathbase +read -p "Enter the number of runs: " num_runs + +set -x # Enable printing of each command before it's executed +# Random baselines +oaieval incontext_rl/random incontext_rl.v0 --record_path ${logpathbase}explanations/random.log +oaieval incontext_rl/random 
incontext_rl.raw.v0 --record_path ${logpathbase}raw/random.log + +for (( run=1; run<=num_runs; run++ )) +do + echo "Run #$run" + # Use explanations variant + # Direct + oaieval generation/direct/gpt-4-turbo-preview incontext_rl.v0 --record_path ${logpathbase}explanations/gpt-4-turbo-preview_${run}.log + oaieval generation/direct/gpt-3.5-turbo incontext_rl.v0 --record_path ${logpathbase}explanations/gpt-3.5-turbo_${run}.log + + # Raw variant + # Direct + oaieval generation/direct/gpt-4-turbo-preview incontext_rl.raw.v0 --record_path ${logpathbase}raw/gpt-4-turbo-preview_${run}.log + oaieval generation/direct/gpt-3.5-turbo incontext_rl.raw.v0 --record_path ${logpathbase}raw/gpt-3.5-turbo_${run}.log + +done + +echo Done running experiments, all logs in $logpathbase + +echo Producing plots for use_explanations variant, outputs to $outputdir +python plot_experiments.py --log_dir $logpathbase/explanations --out_dir $outputdir/explanations +echo Producing plots for raw variant, outputs to $outputdir +python plot_experiments.py --log_dir $logpathbase/raw --out_dir $outputdir/raw +set +x # Disable printing of each command after they've been executed diff --git a/evals/elsuite/skill_acquisition/eval.py b/evals/elsuite/skill_acquisition/eval.py new file mode 100644 index 0000000000..52c770db7d --- /dev/null +++ b/evals/elsuite/skill_acquisition/eval.py @@ -0,0 +1,428 @@ +import json +import logging +import os +import random +from collections import defaultdict +from typing import Any, Dict, List, Optional, Union + +import evals +import evals.metrics +from evals.api import CompletionFn +from evals.elsuite.skill_acquisition.task_description import TASK_DESCRIPTION +from evals.elsuite.skill_acquisition.utils import ( + PROMPTS, + answer_detected, + get_accuracy, + get_average_bleu_score, + get_average_invalid_retrieval_calls, + get_average_retrieval_calls, + get_average_retrieval_precision, + get_bleu_score, + get_bootstrap_accuracy_std, + get_question_type, + 
get_std_of_difference, + process_answer, + process_view_instruction, + render_intermediate_prompt, + view_instruction_detected, +) +from evals.eval import SolverEval +from evals.solvers.solver import Solver +from evals.task_state import Message, TaskState + +TARGET_LANGUAGES = ["miskito"] +LESSON_FILE_SUFFIX = "_lessons.jsonl" + +logger = logging.getLogger(__name__) + + +class SkillAcquisition(SolverEval): + def __init__( + self, + completion_fns: List[CompletionFn], + samples_jsonl: str, + target_language: str, + knowledge_base_directory: str, + max_replies: int, + seed: int = 6122023, + n_samples: Optional[int] = None, + *args, + **kwargs, + ): + super().__init__(completion_fns, seed=seed, *args, **kwargs) + + assert ( + target_language.lower() in TARGET_LANGUAGES + ), f"Error: target language must be one of {TARGET_LANGUAGES}" + + self.samples_jsonl = samples_jsonl + self.n_samples = n_samples + self.task_description = TASK_DESCRIPTION.format(target_language=target_language) + self.rng = random.Random(seed) + + # Retrieval-related attributes. + self.knowledge_base_directory = self._prefix_registry_path(knowledge_base_directory) + self.files_available = os.listdir(self.knowledge_base_directory) + self.content_by_file: dict[str, dict] = {} + self.max_replies = max_replies # Used as timeout. + + def eval_sample(self, solver: Solver, sample: Dict, rng: random.Random) -> Dict[str, Any]: + """Runs the appropriate private evaluation function depending on the eval phase: retrieval or non-retrieval. + + Args: + solver (Solver): per-sample solver instantiated in parent. + sample (Dict): input to evaluate on. + rng (random.Random): random number generator, used for reproducibility. + + Returns: + Dict[str, Any]: metrics collected during evaluation. 
+ """ + # since we run two discrete experiments per sample, we have to copy the solver ahead of time + non_retrieval_solver = solver.copy() + retrieval_solver = solver.copy() + non_retrieval_out = self._eval_non_retrieval_sample(non_retrieval_solver, sample) + retrieval_out = self._eval_retrieval_sample(retrieval_solver, sample) + metrics_obj = { + "non_retrieval": non_retrieval_out, + "retrieval": retrieval_out, + } + + evals.record.record_metrics(**metrics_obj) + return metrics_obj + + def _eval_non_retrieval_sample(self, solver: Solver, sample: Dict, *_) -> Dict[str, Any]: + """Evaluates the given sample without using retrieval, ie. using the solver directly. + + Args: + solver (Solver): any compatible solver, instantiated just for this sample. + sample (Dict): input to evaluate on. + + Returns: + Dict[str, Any]: metrics collected during evaluation. + """ + task_state = TaskState( + task_description=self.task_description, + messages=[Message(**msg) for msg in sample["input"]], + ) + + result = solver(task_state) + output = result.output + if answer_detected(output): + answer = process_answer(output) + logger.debug(f"Model answered {answer}") + else: + answer = "NO ANSWER DETECTED" + + picked = evals.record_and_check_match( + prompt=sample["input"], + sampled=answer, + expected=[sample["ideal"]], + ) + + out_obj = { + "prompt": sample["input"], + "raw_output": result.output, + "parsed_output": answer, + "expected": [sample["ideal"]], + "correct": picked is not None, + "bleu": get_bleu_score(sample["ideal"], answer), + "question_type": get_question_type(sample["input"][-1]["content"]), + } + return out_obj + + def _eval_retrieval_sample(self, solver: Solver, sample: Dict, *_) -> Dict[str, Any]: + """Evaluates the given sample using retrieval. The retrieval logic is implemented in the _conversation_loop function. + + Args: + solver (Solver): any compatible solver, instantiated just for this sample. + sample (Dict): input to evaluate on. 
+ + Returns: + Dict[str, Any]: metrics collected during evaluation. + """ + files_available_paths = [ + self.knowledge_base_directory / file for file in self.files_available + ] + assert all([file.exists() for file in files_available_paths]) + task_state = TaskState( + task_description=self.task_description, + messages=[Message(**msg) for msg in sample["input"]], + current_state={"files": files_available_paths}, + ) + + output, metrics = self._conversation_loop(solver, task_state) + + if answer_detected(output): + answer = process_answer(output) + logger.debug(f"Model answered {answer}") + elif output == "Context length exceeded.": + answer = "NO ANSWER DETECTED" + logger.warning("Current interaction exceeded model context length.") + else: + answer = "NO ANSWER DETECTED" + logger.debug(f"Model timed out after {metrics['current_replies']} replies.") + + picked = evals.record_and_check_match( + prompt=sample["input"], + sampled=answer, + expected=[sample["ideal"]], + ) + + out_obj = { + "prompt": sample["input"], + "raw_output": output, + "parsed_output": answer, + "expected": [sample["ideal"]], + "correct": picked is not None, + "bleu": get_bleu_score(sample["ideal"], answer), + "ctx_len_exceeded": output == "Context length exceeded.", + "interaction_timed_out": metrics["current_replies"] >= self.max_replies, + "question_type": get_question_type(sample["input"][-1]["content"]), + "lesson_retrieval_calls": metrics["lesson_retrieval_calls"], + "correct_retrieval_calls": metrics["correct_retrieval_calls"], + "invalid_retrieval_calls": metrics["total_retrieval_calls"] + - metrics["correct_retrieval_calls"], + "total_retrieval_calls": metrics["total_retrieval_calls"], + } + return out_obj + + def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: + samples = self.get_samples() + self.rng.shuffle(samples) + samples = samples[: self.n_samples] if self.n_samples is not None else samples + + results = self.eval_all_samples(recorder, samples) +
non_retrieval_results = [result["non_retrieval"] for result in results] + retrieval_results = [result["retrieval"] for result in results] + + baseline_accuracy = get_accuracy(non_retrieval_results) + baseline_std = get_bootstrap_accuracy_std(non_retrieval_results) + + retrieval_accuracy = get_accuracy(retrieval_results) + retrieval_std = get_bootstrap_accuracy_std(retrieval_results) + + delta_accuracy = retrieval_accuracy - baseline_accuracy + + # TODO: decide which metric to report – propagated standard deviation + # from bootstrapping or standard error of the mean estimated from repeats + # of the eval experiments. + delta_std = get_std_of_difference(baseline_std, retrieval_std) + + ctx_len_exceeded_rate = sum( + 1 for result in retrieval_results if result["ctx_len_exceeded"] + ) / len(retrieval_results) + timeout_rate = sum( + 1 for result in retrieval_results if result["interaction_timed_out"] + ) / len(retrieval_results) + + num_translation_samples = len( + [result for result in retrieval_results if result["question_type"] == "translation"] + ) + num_non_translation_samples = len( + [result for result in retrieval_results if result["question_type"] == "non-translation"] + ) + + result = { + "baseline_accuracy": baseline_accuracy, + "baseline_std": baseline_std, + "retrieval_accuracy": retrieval_accuracy, + "retrieval_std": retrieval_std, + "delta_accuracy": delta_accuracy, + "delta_std": delta_std, + "average_retrieval_precision": get_average_retrieval_precision(retrieval_results), + "average_non_retrieval_bleu_score": get_average_bleu_score(non_retrieval_results), + "average_retrieval_bleu_score": get_average_bleu_score(retrieval_results), + "average_retrieval_calls": get_average_retrieval_calls(retrieval_results), + "average_invalid_retrieval_calls": get_average_invalid_retrieval_calls( + retrieval_results + ), + "ctx_len_exceeded_rate": ctx_len_exceeded_rate, + "timeout_rate": timeout_rate, + "num_samples": len(retrieval_results), + 
"num_translation_samples": num_translation_samples, + "num_non_translation_samples": num_non_translation_samples, + } + + return result + + def _view_content( + self, + file_name: str, + section_title: str = None, + sections_visible_to_model: dict[str, set] = defaultdict(set), + sections_viewed: dict[str, set] = defaultdict(set), + ) -> tuple[str, dict[str, set], dict[str, set]]: + """Views content from a JSONL file in the knowledge base. + If a section is provided, only the contents of that section are returned. + If no section is specified, the function returns the table of contents of the file. + + Args: + file_name (str): Name of the file. Full directory prefixed automatically. + section_title (str, optional): Name of the section to view. Defaults to None. + sections_visible_to_model (dict[str, set], optional): Dictionary of sections visible to the model. Defaults to {}. Updated in-place. + sections_viewed (dict[str, set], optional): Dictionary of sections viewed by the model. Defaults to {}. Updated in-place. + + Returns: + tuple(str, dict[str, set], dict[str, set]): A tuple of + the content of the section (if specified) and + the updated dictionaries of sections visible to and viewed by the model. + """ + # TODO: more general file format. + + if file_name in self.content_by_file: + file_content_by_section = self.content_by_file[file_name] + else: + # This should never occur, but if it does it should stop the eval from running. + if not os.path.exists(self.knowledge_base_directory / file_name): + raise ValueError( + f"File {self.knowledge_base_directory / file_name} does not exist." 
+ ) + + file_content_by_section = {} + with open(self.knowledge_base_directory / file_name, "r") as f: + for line in f: + line_dict = json.loads(line) + file_content_by_section[line_dict["title"]] = line_dict["content"] + self.content_by_file[file_name] = file_content_by_section + + if section_title is None: + sections = set(file_content_by_section.keys()) + sections_visible_to_model[file_name] = sections + sections_viewed[file_name].add("Table of Contents") + + return ( + f"Table of contents for {file_name}: {sections}.", + sections_visible_to_model, + sections_viewed, + ) + + sections_viewed[file_name].add(section_title) + return file_content_by_section[section_title], sections_visible_to_model, sections_viewed + + def _conversation_loop( + self, solver: Solver, task_state: TaskState + ) -> tuple[str, Dict[str, int]]: + """Maintains a conversation with the model until it outputs an answer or times out. + The model may request to read a file or a section of a file from the knowledge base. + + Args: + solver (Solver): any compatible solver, instantiated just for this sample. + task_state (TaskState): current task_state, which additionally contains a list of knowledge base files in `current_state`. + + Returns: + tuple[str, Dict[str, int]]: a tuple of the model's output and a dictionary of metrics collected during the conversation. + """ + output = "" + + # Not all retrieval calls are valid, e.g. if the file doesn't exist. + # These two metrics are analogous to an instruction-following rate. + metrics = { + "lesson_retrieval_calls": 0, + "correct_retrieval_calls": 0, + "total_retrieval_calls": 0, + "current_replies": 0, + } + sections_visible_to_model: dict[str, set] = defaultdict(set) + sections_viewed: dict[str, set] = defaultdict(set) + consecutive_instruction_failures = 0 + + while not answer_detected(output) and metrics["current_replies"] < self.max_replies: + if metrics["current_replies"] == 0: + # Beginning of the conversation, prepare instructions. 
+ task_state.task_description = ( + task_state.task_description + + "\n\n" + + PROMPTS["retrieval_instructions"].format(list_of_files=self.files_available) + ) + if len(sections_viewed.items()) > 0: + intermediate_prompt = render_intermediate_prompt(sections_viewed) + task_state.messages += [Message(role="system", content=intermediate_prompt)] + + output = solver(task_state).output + task_state.messages += [Message(role="assistant", content=output)] + metrics["current_replies"] += 1 + + if view_instruction_detected(output) or answer_detected(output): + consecutive_instruction_failures = 0 + + if view_instruction_detected(output): + file, section = process_view_instruction(output) + metrics["total_retrieval_calls"] += 1 + + if file.endswith(LESSON_FILE_SUFFIX): + metrics["lesson_retrieval_calls"] += 1 + + # Handle any errors by logging and re-prompting the model. + if file not in self.files_available: + task_state.messages += [ + Message( + role="system", + content=PROMPTS["wrong_file"].format( + file=file, knowledge_base=self.files_available + ), + ) + ] + logger.debug( + f"Model tried to view {file}, which does not exist in the knowledge base:\n{json.dumps(self.files_available, indent=4)}." + ) + continue + + if section is not None and section not in sections_visible_to_model[file]: + task_state.messages += [ + Message( + role="system", + content=PROMPTS["wrong_section"].format( + file=file, + section=section, + table_of_contents=sections_visible_to_model[file], + ), + ) + ] + logger.debug( + f"Model tried to view section {section} in file {file}, which does not exist.\nAvailable sections are {json.dumps(list(sections_visible_to_model[file]), indent=4)}." + ) + continue + + # If no errors, view the content and update the task state. 
+ content, sections_visible_to_model, sections_viewed = self._view_content( + file, section, sections_visible_to_model, sections_viewed + ) + task_state.messages += [ + Message( + role="system", + content=PROMPTS["present_content"].format( + file=file, + section=section if section is not None else "Table of Contents", + content=content, + ), + ), + ] + metrics["correct_retrieval_calls"] += 1 + if section is None: + logger.debug(f"Model viewed table of contents for file {file}: {content}") + else: + logger.debug(f"Model viewed section {section} in file {file}.") + elif not answer_detected(output): + if consecutive_instruction_failures >= 3: + return "Model failed to follow instructions.", metrics + + consecutive_instruction_failures += 1 + logger.debug( + f"Model output did not contain a view instruction or an answer: {output}" + ) + + # Flag & move onto next sample if context length exceeded. + if ( + "'code': 'context_length_exceeded'" in output + or "Please reduce your prompt; or completion length" in output + ): + return "Context length exceeded.", metrics + + task_state.messages += [ + Message( + role="system", + content="Your output did not contain a view instruction or an answer. Please try again.", + ) + ] + + return output, metrics diff --git a/evals/elsuite/skill_acquisition/readme.md b/evals/elsuite/skill_acquisition/readme.md new file mode 100644 index 0000000000..2d5a8fafcb --- /dev/null +++ b/evals/elsuite/skill_acquisition/readme.md @@ -0,0 +1,64 @@ +# Skill acquisition + +This eval tests models' ability to learn a skill with minimal human involvement. In the initial release, models are evaluated on questions related to the [Miskito language](https://en.wikipedia.org/wiki/Miskito_language). Some samples are translation and others are language manipulation exercises. + +## Usage +Run with: +```bash +oaieval skill_acquisition.miskito +``` + +Where the solver can be any generation solver in `evals/registry/solvers/defaults.yaml`, eg. 
`generation/cot/gpt-3.5-turbo-16k`. + +## Evaluation process +Every time the eval is run, the model is evaluated twice. The first time, it answers the question directly using whatever prompting technique is executed by the solver you choose. The second time the model runs in a loop, interacting with an interface which gives it access to a knowledge base. The knowledge base contains text files, some of which are relevant for answering the question, while others are unrelated. If models can use this interface to increase their performance on the task, we can say that they've improved or acquired their language translation and manipulation skills. + +## Prompts +See `skill_acquisition/utils.py` to review/adjust the prompts used in this eval. + +## Datasets + +The dataset is generated from [this language course](https://en.wikibooks.org/wiki/Miskito), which comprises 229 questions. We further split this into manipulation-only (`miskito_test_manipulation.jsonl`) and translation-only (`miskito_test_translation.jsonl`) subsets. + +## Variants + +We test zero-shot and few-shot prompting techniques on the dataset: + +| Dataset | Zero-shot | Few-shot | +| --------- | -------- | -------- | +| Miskito | `skill_acquisition.miskito.zero-shot.full`|`skill_acquisition.miskito.few-shot.full`| + +The `full` in this case refers to the size of the dataset – there are also variants for testing where only 5 examples are considered, called `dev5`. For full details, look at `evals/registry/skill_acquisition/skill_acquisition.yaml`. + +For the few-shot setting, use the eval-specific solvers in `evals/registry/solvers/skill_acquisition.yaml` to avoid train/test leakage. 
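The knowledge-base interface described under "Evaluation process" can be pictured with a small sketch. It assumes, as the eval and scraping code in this diff suggest, that each knowledge-base file is JSONL with one `{"title": ..., "content": ...}` record per line. The helper names `load_sections` and `view_section`, the demo file name, and the demo records are hypothetical and only illustrate the idea; they are not the eval's actual functions or data.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional


def load_sections(path: Path) -> dict[str, str]:
    """Read a JSONL knowledge-base file into {section title: section content}."""
    sections: dict[str, str] = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            sections[record["title"]] = record["content"]
    return sections


def view_section(sections: dict[str, str], section_title: Optional[str] = None) -> str:
    """Return the table of contents, or one section's content if a title is given."""
    if section_title is None:
        return f"Table of contents: {sorted(sections)}"
    if section_title not in sections:
        # Assumption: an unknown section is reported back rather than raising.
        return f"Unknown section {section_title!r}; available: {sorted(sections)}"
    return sections[section_title]


# Demo with a throwaway file; the titles and contents are made up.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo_lessons.jsonl"
    demo_path.write_text(
        json.dumps({"title": "Greetings", "content": "Placeholder lesson text."})
        + "\n"
        + json.dumps({"title": "Numbers", "content": "Another placeholder."})
        + "\n",
        encoding="utf-8",
    )
    demo_sections = load_sections(demo_path)

print(view_section(demo_sections))
print(view_section(demo_sections, "Greetings"))
```

In this sketch the model would first ask for a file's table of contents and then request a specific section by title, which mirrors the two kinds of view requests handled in the eval's conversation loop.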
+
+## Token Usage Estimates
+
+Below is a rough estimate of the total number of tokens consumed by some variations of the eval, including both input and output tokens:
+
+| Model | Solver | Prompt tokens | Completion tokens | Total tokens |
+| --- | --- | --- | --- | --- |
+| gpt-3.5-turbo | direct | 1,000,000 | 23,000 | 1,050,000 |
+| gpt-3.5-turbo | cot | 930,000 | 120,000 | 1,050,000 |
+| gpt-3.5-turbo | fewshot | 450,000 | 9,600 | 460,000 |
+| gpt-3.5-turbo-16k | direct | 1,400,000 | 24,000 | 1,500,000 |
+| gpt-3.5-turbo-16k | cot | 2,000,000 | 120,000 | 2,100,000 |
+| gpt-3.5-turbo-16k | fewshot | 610,000 | 10,000 | 620,000 |
+| gpt-4-base | direct | 1,800,000 | 420,000 | 2,200,000 |
+| gpt-4-base | cot | 4,700,000 | 890,000 | 5,600,000 |
+| gpt-4-base | fewshot | 1,400,000 | 320,000 | 1,700,000 |
+| gpt-4-1106-preview | direct | 1,700,000 | 100,000 | 1,800,000 |
+| gpt-4-1106-preview | cot | 1,600,000 | 99,000 | 1,700,000 |
+| gpt-4-1106-preview | fewshot | 1,700,000 | 95,000 | 1,800,000 |
+| gpt-4-32k | direct | 1,800,000 | 80,000 | 1,900,000 |
+| gpt-4-32k | cot | 2,700,000 | 180,000 | 2,900,000 |
+| gpt-4-32k | fewshot | 190,000 | 6,000 | 190,000 |
+
+## Version History
+v0: Initial version released
+
+
+## Contribution statement
+
+Eval design, implementation, and results evaluation were primarily conducted by Andrei Alexandru. Giulio Starace was responsible for code reviews throughout the implementation process, along with fine-grained feedback on the project in general. Additional guidance was provided by (alphabetically by last name) Steven Adler, James Aung and Chan Jun Shern, who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation.
+ diff --git a/evals/elsuite/skill_acquisition/scraping/human_rights.html b/evals/elsuite/skill_acquisition/scraping/human_rights.html new file mode 100644 index 0000000000..c6d49a320c --- /dev/null +++ b/evals/elsuite/skill_acquisition/scraping/human_rights.html @@ -0,0 +1,2839 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + OHCHR | Universal Declaration of Human Rights - Miskito + + + + + + + + + + + + + Skip to main content + + +
+

Universal Declaration of Human Rights - Miskito

+
+

SOURCE

+

Comité para la Defensa de los Derechos Humanos, Honduras

+
Miskito
+ +
Language Profile
+ + +

TOTAL SPEAKERS

160,000 (1982)

USAGE BY COUNTRY (OFFICIAL LANGUAGE)

Home Speakers: Nicaragua, Honduras

BACKGROUND

It belongs to the Misumalpan family (Macro-Chibchan subgroup) and is spoken by 11,000 people in Honduras and over 150,000 people in Nicaragua. It is a language of trade in Honduras, whereas it is widely used in Nicaragua, both in primary schools and among older people.

+ +

Upla sut Raitka nani ba Tasba aiska laka ba Bapuia

+

Asla Takanka tara ba Naha Upla sut Raitka nani ba Tasba aiska laka ba Bapuia,

+

Sut lukanka baku, upla sut, kantri nani sut, trai kaikaia; baku, kKumi bani, dakni nani bani, nahara kat luki, tabaikaia, Smalkanka bak, kul nani bak, naha Raitka nani ba, Bara, Prika laka naniba, pramis kum Dauki, kantri laka nani bilkak, at apia, tasba, aiska, laka, nani bilkak, atsa, yaka kakaira takaia bara kulkaia wan kantri nani bui.

+

[Preamble]

+

Kulkanka 1

+

Upla sut ba kulkanka lakara, airaitka nanira bara pri, sin, aikuki, baku takisa. Bamna sins laka bri baku, lukanka bain pri baku aimuihni lakara, pana pana tabaikan kaiasa.

+

Kulkanka 2

+

Naha lakara pas taurá wisa, upla baniba airaitka brisa, bara sin, pri san: nisanka kulkras, taya maplika kulkras, mairin sapa waikna sapa kulkras, bila aisanka kulkras, ani gadkara mayuni sapa kulkras, aipulitikka lukanka ba, apia sa kaka, dia dia dukya kabia sin kulkras, wan tasbaya wina sapa, yuya kira sapa, wan tasbayara baikan sapa, apia kaka, dia dia walanatkara sin kulkras kira.

+

Baku sin wan kantri pulitik ka laka bui sin, wan kantri laka nani bui sin, apia kaka, tasba aiska laka mita sin, apia laka, ani tasbayara iwi ba bui sin upla kumira sin mayara kulkan kaia apia sa. Bamna kantri wala nani natkara iwi bara, kankri wala laka munhtara iwiba, alba laka natkara nanira iwiba sin saura baku kulkan kaia apia sa.

+

Kulkanka 3

+

Upla sut ba airaitka brisa airayaka kum brieia pri lakara iwaia upla baku, aimain kira kaia.

+

Kulkanka 4

+

Sip apia upla kum alba lakara kaia, bamna, baha natkara yus munan kaia sin, kan baha laka apu sa.

+

Kulkanka 5

+

Upla kumi sinra, sip apia sa uba saura munan kaia, silak mankan kaia, swira pask an, an upla apia baku munaia.

+

Kulkanka 6

+

Upla bani ai raitka brisa anira kabia sin lâkat upla baku kulkan kaia.

+

Kulkanka 7

+

La mawanra sut ba kumi kulkan sa, bara kumi bani la ba bui aikuki baku main kaikisa, upla sut ba airaitka brisa aikuki baku main kaikan kaia, wala nani bui mayara kulkan kabiara sin; naha laka tara bapanna kulkras baku.

+

Kulkanka 8

+

Upla sut la airaitka brisa tabaikanka uplika pain kum brikaia, wan kantri laka mawanra, baku mika sipsa airaitka nani kulkras munbia sin, la kulkan ka ba brih wabia.

+

Kulkanka 9

+

Upla kumi sin, sip apia sa ban kakalhni silak ra mangki saura munaia.

+

Kulkanka 10

+

Upla sutba, airaitka brisa, wala nani baku, upla sut mawanra, an la kum taibanka apu kira bui aiturka ba walan kaia. Baku mika airaitka nani, bara, witin daukaia dukia nani ba marikan kabia,m apia kaka, dia dia saurka dukiara munansa kapa sip kabia laki kaikaia.

+

Kulkanka 11

+
+

1. Upla bani ba, dia dia saurka dudiara laura lulkansa bara, airaitkabrisapas taura aiturkaba aisaia sip kabia, kau taibi munras bara, la tankaba kat, baku mika, bilka nani yâban kaiasa, upla sut mawanra, bapi buaia sip kaia.
2. Upla kumira sin, saura munan kaia apiasa pât kum dukiara, daukanka ba puyara, saura pali kulkan apia sa kaka, wan kantry raitka nani bui sin, ba wisi sin, saurka uba tara kulkan kaia apia sa, baha pât kaba dukiara.

+

Kulkanka 12

+

Upla bani Rayakaba, Wala bui turban kaia apiasa, Tâika nanira kabia sinm, watla bilara kabia sin, dukya nanira kabia sin ki, apia kaka, nina sauhkaia, rispik ka alahbaia; upla bani ba, airaitka brisa, baha nani saurka mapara la bui main kaikaia.

+

Kulkanka 13

+
+

1. Upla bani ba airaitka brisa, pri pali taukaia bara, kantri bilara tasba kum bri kaia, iwaia lahma.
2. Upla bani ba, airaitka brisa aitasbaya wina taki waia, bara, kli balaia, dimaia sin.

+

Kulkanka 14

+
+

1. Bankra dia dia patka dukiara nina blikisa kaka, upla bani ba, airaitka brisa, natka kumpliki, tasba wala kum distika makabaia, wala nani baku auya pah iwikaia dukiara.
2. La nani bui pat bahki nani kulkanba dukiara ban sin Tasba Aiska Asla Takanka brinka nani bara lukanka ta nani ba kulkras kira, naha Raitkana makabaia.

+

Kulkanka 15

+

Upla baniba, airaitka brisa kantri kumra iwikaia apia kaka, kantri walara iwaia lukankaba yabalka prakan kaiasa.

+

Kulkanka 16

+
+

1. Waikna bani, mairin bani ba, airaitka brisa pyua alkansa bara, nisanka kulkras, ani kantrikara iwiba kulkras, ani Gadkara mayuniba kulkras kira, sipsa marittakaia; sahwaia sin, baku mika, marit takansa bara, apia, mahka wal swibia sinki wal baku iwaiasa.
2. Marit laka daukan, kabia marit uplika naniba, aikupya wilinkira sakaka.
3. Panli laka, upia sut wina kau yamnika bak sakan dukia kumsa, bamna, upla sut bui, Gabament bui sin main pali kaikan kaiasa.

+

Kulkanka 17

+
+

1. Upla bani ba airairka brisa aidukia pawaia lahma brikaia, yakan lakara, bamna, upla wala nani aikuki asla lakara sin.
2. Upla kumi ra sin, Aidukia Pawaia lahma pat taki ba, yabalka prakaia.

+

Kulkanka 18

+

Upla bani ba, ai raitka brisa pri lakara dia dia lukaia, lukanka pain nani brikaia, bara, ani ani Gadkara lukaia, naha raitkana ra luki sipsa Gad Wala nani ra lukaia, upla nanira marikaia sipkabia; yakan kabia, upla wala sin, Aikulkanka aikuki kabia sin; upla nani mawanra, prakan ra dauki kaia, lahma, kulkaia, lahma, bara, laki kaikaia mata kabia sin.

+

Kulkanka 19

+

Upla sutba, airaitka brisa prikaia ailukankara aisankara; naha raitkara aisisa, dia dia lukanka dukiara, upla kumira sin warbras kaia sa tanka plikaia sip kaia, dia dia turiba nu kaia, bara, wala nanira maisapakaia, kantry ka kulkras, dia dia bilkak kat kabia sin.

+

Kulkanka 20

+
+

1. Upla bani ba, airaitka brisa pri lakara aslatakanka kum daukaia, bara, aslatakanka lamni laka kat brikaia.
2. Upla kumi sin, sip apia sa, taibi munankaia asla takanka kumra tilara kaia.

+

Kulkanka 21

+
+

1. Upla sutba, airaitka brisa, gabament dukia tilara kaia, ban sin, wala nanira tabaikaia baha nani tilara kabia.
2. Upla sutba airaitka brisa, wala nani aikuki baku Gabament Warkka nani tilara kaia.
3. Tawan aiska brinka ba, upla sut karhnika sa gabament tanira; naha brinka ba, gabament bani mangkisa bara, klir lakisa, lulkaia laka ba kat kulki, aikuki baku, bara, upla bani aikupya laka kat, ban natka wal nani ni daukkbia sin.

+

Kulkanka 22

+

Upla bani ba, upla baku airaitka brisa, main kaikankaia, baku sin, gabament tabaikanka baku, tasba aiska buisin asla takanka nani bilka brisa bara, gabament dukia nani sut kulki, pawaia natka nani, asla takaia nani, bara, aikulkanka nani sutba brin kaiasa, baku mika, upla baku ailukanka kat, bara, pri pali pawaia sip kabia.

+

Kulkanka 23

+
+

1. Upla bani ba airaitka brisa warkka kum brikaia, bara, aikupya pahkira wark plikaia, wala nani baku kaia, wark pain brikaia, bara, warka apu pyuara tabaikan kaiasa.
2. Upla sutba airaitka brisa wala mita mayara kulkankaia apiasa, aiwarkka daukiba baku mâna sin baku kaia su.
3. Upla wark taki naniba airaitka aprisa aimana kum brikaia, mana sin painkaia, baku sip kabia aitaya nani main kaikaia, upla baku iwaia, ban sin bilka sa kaka natka wala nani pliki mainka kaikan kaiasa.
4. Upla bani ba airaitka risa aslatakanka dakni paskaia bara tilara dimaia sin, aibrinka nani dukyara aiklabaia mata.

+

Kulkanka 24

+

Upla bani ba ai raitka brisa ris briaia, riska lilika briaia ai wark ka pyua kum brikaia, bara, baku sin ris pyua yari nani sin, ai mana wal.

+

Kulkanka 25

+
+

1. Upla bani ba, airaitka brisa iwaia natka pain kum brikaia, baku, sip kabia witin, bara, aitaika nani sin, siknis nani luhakaia; kau purara ban kulkan kaiasa: plun ba, praka, utla ba, sika nani yabaiaba, upla baku mainka kaikaia ba; baku sin, airaitka brisa wark apu sa pyuara, mainka kaikan kaia, siknis sa pyuara saua sakan sa bara, pyarka takansa bara, almuk takan sa bara, ban sin dia dia bui kra aidukia nani sut sauhki tikan sa bara, tabaikan karia.
2. Mairin ba, kwihra sa bara, baikan pyuara sin airaitka brisa main kaikan kaia; bara, dia dia brinka nani sut yâban kaia, ani tuktika, marit laka kat kulki baikan kabia, apia, tnayara baikan kabia sin airaitka brisa wal baku main kaikan kaia.

+

Kulkanka 26

+
+

1. Upla bani ba, airaitka brisa aaisinska kwakaia, smalkanka ba pri natkara kaiasa, ulbaia ba pan, aisikaikaiaba pan. Baku sin, karhna munan kaiasa ulbaia ba, bara aisikaikaiaba, lan takaia; lila kulka naniba, sut lahma kaiasa baku sin, kul nanira dimaia ba sip takan kaiasa sut lahma, kumi bani daukan kaba kaiki.
2. Kul smalkanka brinka kabia; upla ba; upla baku lukanka brikaia dukiara, smalkaia, baku sin, upla bani airaitka ba kulkaia, bara, upla bani aiprika laka kum kum bri nmaniba kulkaia dukiara smalkan kaia; tanka pain briaia bra, aidahra pain walaia, bara, pana laka tasba wala nani aikuki bara, indian nani sut aikuki kau taura kulkan kaiasa, baku si kupya kumi laka, upla sut mata, tasba aiska asla takanka daukiba ta baikan kaiasa.
3. Tuktan nani aisika bani pa, sip kabia ailuhpya dia a dia lan takaia ba, witin pali pliki yabaia.

+

Kulkanka 27

+
+

1. Upla bani ba, pri lakara aitasbaya lukanka laka nani tilara, kaia, baku sin paskanka nani tilara, bara, sins laka tara nani pawanka dilara kaia, ban sin baha lilika briaia.
2. Upla bani ba, airaitka brisa airispik ka laka ba, bara, aidukia nani ba sin, main kaikan kaia, witin aisinska tihukani, aiulbanka nani bak kra, apia, aipaskanka nani bak brisa kaka.

+

Kulkanka 28

+

Upla bani ba, airaitka brisa, tasba aiskara, bara, aitasbayara sin la kat, bara, wapni laka kata iwaia; naha laka bapan na; upla nani raitka ba kulkaia, bara, pri laka ba kulkaia nani ba, sut alkaia mata.

+

Kulkanka 29

+
+

1. Upla sut bui ai tawan kara, rispik ka ba yaban kaiasa bara baman upla baku ai auya pah pawisa.
2. Ai raitka nani ba , kulki, bara, ai prika lakaba wal, ai auya pah kaiasa kaka, upla bani ba la nani bapanba yabalka kat wapaia sa, baku mika, upla wala nani raitka, bara, prika laka aniba kaikaia, kulkaia sin, baku rispik ka yaia la kat iwaia, bara pana pana kupya pliki natka nani iwaia sa, kaka.
3. Naha raitka naniba, bara prika laka nani ba kulkan kaia apiasa, Tasba Aiska Asla Takanka lukanka mapara sa kaka.

+

Kulkanka 30

+

Naha laka bapanna, gavament ra kabia, dakni kumra kabia, upla kumra kabia, bilka ya bansa lukan kaia apia sa, raitka nani, bara, prika laka nani naha lakara aisan na, alki taibi munaia upla wala nanira.

+ +
+ + + + + + + + + + + + + diff --git a/evals/elsuite/skill_acquisition/scraping/scrape_distractor_articles.py b/evals/elsuite/skill_acquisition/scraping/scrape_distractor_articles.py new file mode 100644 index 0000000000..93e248b107 --- /dev/null +++ b/evals/elsuite/skill_acquisition/scraping/scrape_distractor_articles.py @@ -0,0 +1,96 @@ +# %% +import json +import re + +import requests +from bs4 import BeautifulSoup +from markdownify import markdownify as md + +articles_to_scrape = [ + "https://en.wikipedia.org/wiki/Mosquito", + "https://en.wikipedia.org/wiki/Mosquito_Coast", + "https://en.wikipedia.org/wiki/Nicaragua", + "https://en.wikipedia.org/wiki/Honduras", + "https://en.wikipedia.org/wiki/Miskito_language", + "https://en.wikipedia.org/wiki/Miskito_people", +] +dirpath = "evals/registry/data/skill_acquisition/distractor_articles/" + + +def clean_soup(content): + for infobox_tag in content.find_all("table", class_="infobox"): + infobox_tag.decompose() + for figure_tag in content.find_all("figure"): + figure_tag.decompose() + for style_tags in content.find_all("style"): + style_tags.decompose() + reflist_div = '
") + + sections = {} + for heading_text in headings: + if "" not in heading_text: + sections["Introduction"] = clean_heading_text(heading_text) + continue + span = heading_text[: heading_text.index("")] + heading_title = BeautifulSoup(span, "html.parser").contents[0].contents[0] + text = heading_text[heading_text.index("") + 5 :] + if heading_title not in ["References", "See also", "External links", "Footnotes"]: + sections[heading_title] = clean_heading_text(text) + + article_title = article.split("/")[-1] + + print(f"Scraped {article_title} successfully. Headings: {sections.keys()}\n") + filename = f"{article_title.lower()}.jsonl" + + with open(dirpath + filename, "w") as f: + for k, v in sections.items(): + f.write(json.dumps({"title": k, "content": v}, ensure_ascii=False) + "\n") + +# Separate code to scrape human rights article, as it's in a different format. +with open("human_rights.html", "r") as f: + html = f.read() + +soup = BeautifulSoup(html, "html.parser") +content = soup.find("div", class_="migrated-content") +md_content = md(str(content)).replace("\xa0", " ").replace("\u3000", " ") + +with open(dirpath + "human_rights_miskito.jsonl", "w") as f: + f.write( + json.dumps( + {"title": "Declaration of Human Rights in Miskito", "content": md_content}, + ensure_ascii=False, + ) + + "\n" + ) diff --git a/evals/elsuite/skill_acquisition/scraping/scrape_miskito.py b/evals/elsuite/skill_acquisition/scraping/scrape_miskito.py new file mode 100644 index 0000000000..697b5667cd --- /dev/null +++ b/evals/elsuite/skill_acquisition/scraping/scrape_miskito.py @@ -0,0 +1,135 @@ +# %% +import json + +import bs4 +import requests +from bs4 import BeautifulSoup +from markdownify import markdownify as md + +# TODO: make sure italicised text is crawled properly and that hints are excluded from answers. +# TODO: Split any multi-part questions into individual questions. 
+ +miskito_base_url = "https://en.wikibooks.org/wiki/Miskito/Lesson_{idx}" + + +def process_practice_section_div(practice_div: bs4.element.Tag): + tds = practice_div.find_all("td") + instructions = ( + md(str(tds[1])) + .replace("*", "") + .replace("|", "") + .strip() + .replace("What do these mean?", "Translate to English:") + .replace("What do these sentences mean?", "Translate to English:") + ) + question_text = tds[2] + questions = question_text.find_all("li") + questions = [str(q.contents[0]) for q in questions] + answer_text = tds[3] + answers = answer_text.find_all("li") + answers = [str(a.contents[0]) for a in answers] + return instructions, questions, answers + + +def extract_toc_sections(content: bs4.element.Tag): + toc = content.find_all("div", class_="toc")[0] + lis = toc.find_all("li", class_="toclevel-1") + lis = [li.find_all("span", class_="toctext")[0].contents[0] for li in lis] + + lis = [md(str(li)).strip().replace("*", "") for li in lis] + return lis + + +def process_miskito_page(): + qa_pairs_by_lesson = {} + articles_without_qa_pairs = [] + for idx in range(1, 11): + response = requests.get(miskito_base_url.format(idx=idx)) + soup = BeautifulSoup(response.text, "html.parser") + content = soup.find("div", class_="mw-content-ltr mw-parser-output") + + # Extract the question-answer pairs. + divs_with_specific_style = content.find_all( + "div", style=lambda value: value and "width:300px; float:right;" in value + ) + lesson_qa_pairs = [] + for i, div in enumerate(divs_with_specific_style): + if i == 0 and idx == 1: # First section of first lesson is not in the same format. 
+ instructions = "Translate to English:" + questions = div.find_all("ul")[0].find_all("li") + questions = [str(q.contents[0]) for q in questions] + answers = div.find_all("ul")[1].find_all("li") + answers = [str(a.contents[0]) for a in answers] + lesson_qa_pairs += [ + {"question": q, "answer": a, "instructions": instructions} + for q, a in zip(questions, answers) + ] + continue + instructions, questions, answers = process_practice_section_div(div) + for q, a in zip(questions, answers): + lesson_qa_pairs += [{"question": q, "answer": a, "instructions": instructions}] + qa_pairs_by_lesson[f"lesson_{idx}"] = lesson_qa_pairs + + # Remove them from the page and store the page contents. + for div in divs_with_specific_style: + div.decompose() + + articles_without_qa_pairs += [content] + + return qa_pairs_by_lesson, articles_without_qa_pairs + + +# %% +# Write to file: all questions by lesson, and all questions in evallib format. +qa_pairs_by_lesson, clean_articles = process_miskito_page() +qa_by_lesson_file = "miskito_qa_pairs_by_lesson.jsonl" + +with open(qa_by_lesson_file, "w") as f: + for lesson, qa_pairs in qa_pairs_by_lesson.items(): + f.write(json.dumps({"lesson": lesson, "qa_pairs": qa_pairs}) + "\n") + +miskito_qa = "miskito_qa.jsonl" +with open(miskito_qa, "w") as f: + for lesson, qa_list in qa_pairs_by_lesson.items(): + for qa_dict in qa_list: + instructions = qa_dict["instructions"][:-1] + ": " + f.write( + json.dumps( + { + "input": [{"role": "user", "content": instructions + qa_dict["question"]}], + "ideal": qa_dict["answer"], + }, + ensure_ascii=False, + ) + + "\n" + ) +# %% +as_text = [str(a).split("

")[1:] for a in clean_articles] +sections_by_heading = {} +for article in as_text: + for heading in article: + hsoup = BeautifulSoup(heading, "html.parser") + heading_name = ( + md(str(hsoup.find("span", class_="mw-headline").contents[0])).replace("*", "").strip() + ) + hsoup.find("span", class_="mw-editsection").decompose() + content = ( + md(str(hsoup)) + .strip() + .replace("*", "") + .replace("|", "") + .replace("What do they mean?", "") + .replace(" --- ", "") + .replace("\u2003", " ") + .replace(" ", " ") + ) + content = content.split(" Study ")[1] if "Study " in content else content + sections_by_heading[heading_name] = content.strip() + +sections_by_heading +# %% +file = "lessons_no_exercises.jsonl" +with open(file, "w") as f: + for heading, content in sections_by_heading.items(): + f.write(json.dumps({"title": heading, "content": content}, ensure_ascii=False) + "\n") +# %% diff --git a/evals/elsuite/skill_acquisition/scripts/make_plots.py b/evals/elsuite/skill_acquisition/scripts/make_plots.py new file mode 100644 index 0000000000..01eab83412 --- /dev/null +++ b/evals/elsuite/skill_acquisition/scripts/make_plots.py @@ -0,0 +1,204 @@ +import argparse +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +from evals.utils import log_utils + +PLOT_TITLES_BY_METRIC = { + "overall_accuracy": "Accuracy", # ie. 
both retrieval and non-retrieval in one plot + "baseline_accuracy": "Baseline accuracy (non-retrieval)", + "retrieval_accuracy": "Retrieval accuracy", + "average_retrieval_precision": "Average retrieval precision", + "average_non_retrieval_bleu_score": "Average non-retrieval BLEU score", + "average_retrieval_bleu_score": "Average retrieval BLEU score", + "average_retrieval_calls": "Average retrieval calls", + "average_invalid_retrieval_calls": "Average invalid retrieval calls", + "bleu_score": "BLEU score", + "correct_call_rate": "Correct call rate", + "invalid_call_rate": "Invalid call rate", + "timeout_rate": "Timeout rate", + "ctx_len_exceeded_rate": "Context length exceeded rate", +} + +UNIT_METRICS = set( + ["correct_call_rate", "invalid_call_rate", "timeout_rate", "ctx_len_exceeded_rate"] +) + + +def extract_metrics(datadir: Path) -> pd.DataFrame: + df_rows = [] + for path, results in sorted(list(log_utils.get_final_results_from_dir(datadir).items())): + spec = log_utils.extract_spec(path) + solver_path = Path(spec["completion_fns"][0]) + model = solver_path.name + solver = solver_path.parent.name + # Remove root section of path, which is the eval name + solver_path = solver_path.relative_to(solver_path.parts[0]) + df_rows.append({"solver": solver, "model": model, **results}) + df = pd.DataFrame(df_rows) + + return df + + +def make_plot( + df: pd.DataFrame, + outpath: Path, + metric="baseline_accuracy", + min_ylim=0, + max_ylim=0.08, + dataset="miskito", +): + plt.figure() + sns.set_theme(style="whitegrid") + # Calculating mean and SEM + grouped = df.groupby(["model", "solver"])[metric].agg(["mean", "sem"]).reset_index() + + def compute_sem(x): + sem = x.std() / (len(x) ** 0.5) + sem2 = sem * 2 # 95% confidence interval + return (x.mean() - sem2, x.mean() + sem2) + + # Plotting + sns.set(style="whitegrid") + sns.barplot(x="model", y="mean", hue="solver", data=grouped, errorbar=compute_sem, capsize=0.1) + plt.xticks(rotation=30, ha="right") + 
plt.ylim(min_ylim, max_ylim) + + # Some of the metrics are in [0, 1]. + if metric in UNIT_METRICS: + plt.ylim(0, 1) + + plt.title(PLOT_TITLES_BY_METRIC[metric] + f" on {dataset.capitalize()} Q&A dataset") + plt.xlabel("Model") + plt.tight_layout() + plt.savefig(outpath) + plt.close() + + +def make_side_bar_plot( + df: pd.DataFrame, + outpath: Path, + metric="overall_accuracy", + min_ylim=0, + max_ylim=0.1, + dataset="miskito", +): + if metric == "overall_accuracy": + df_clean = df[["model", "solver", "baseline_accuracy", "retrieval_accuracy"]] + elif metric == "bleu_score": + df_clean = df[ + ["model", "solver", "average_non_retrieval_bleu_score", "average_retrieval_bleu_score"] + ] + + fig, ax = plt.subplots(figsize=(10, 5)) + # df_clean = df_clean.drop(columns=["solver"]) + df_clean.set_index(["model", "solver"], inplace=True) + + # Group by 'model' and calculate mean and SEM + grouped = df_clean.groupby(["model", "solver"]).agg(["mean", "sem"]) + xlabels = [f"{model}/{solver}" for model, solver in grouped.index] + + # Prepare data for plotting + means = grouped.xs("mean", axis=1, level=1) + errors = grouped.xs("sem", axis=1, level=1) + + # Plotting + means.plot(kind="bar", yerr=errors, capsize=4, ax=ax) # Removed 'stacked=True' + + ax.set_ylabel(metric) + ax.set_xticklabels(xlabels, rotation=30, ha="right") + ax.set_xlabel("model/solver") + ax.set_ylim(min_ylim, max_ylim) + + fig.tight_layout(pad=3.0) + fig.suptitle(PLOT_TITLES_BY_METRIC[metric] + f" on {dataset.capitalize()} dataset") + fig.savefig(outpath) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--log-dir", "-d", type=str, required=True) + parser.add_argument("--out-dir", "-o", type=str, default="./outputs") + args = parser.parse_args() + log_dir = Path(args.log_dir) + out_dir = Path(args.out_dir) + + out_dir.mkdir(exist_ok=True, parents=True) + + datasets = os.listdir(log_dir) + + for dataset in datasets: + print(f"Extracting data for eval dataset 
{dataset}...") + df = extract_metrics(log_dir / dataset) + + # Rename some of the solver values so they can be represented in the same plot. + df.loc[df["solver"] == "cot_hhh", "solver"] = "cot" + df.loc[df["solver"] == "hhh", "solver"] = "direct" + df.loc[df["solver"] == "fewshot_direct", "solver"] = "fewshot" + + # TODO: report directly as 'average_correct_calls' in future and remove this rename. + df.rename(columns={"average_retrieval_precision": "average_correct_calls"}, inplace=True) + df["correct_call_rate"] = df["average_correct_calls"] / df["average_retrieval_calls"] + df["invalid_call_rate"] = ( + df["average_invalid_retrieval_calls"] / df["average_retrieval_calls"] + ) + + print(f"Plotting other metrics for eval dataset {dataset}...") + + # Generate bar plots for all other metrics. + core_metrics = ( + [] + ) # ["baseline_accuracy", "retrieval_accuracy", "average_non_retrieval_bleu_score", "average_retrieval_bleu_score"] + auxiliary_metrics = [ + "correct_call_rate", + "invalid_call_rate", + "timeout_rate", + "ctx_len_exceeded_rate", + ] + for metric in core_metrics + auxiliary_metrics: + make_plot( + df[["model", "solver", metric]].copy(), + out_dir / f"{dataset}_{metric}.png", + metric, + dataset=dataset, + ) + + print(f"Plotting headline metrics for eval dataset {dataset}...") + + # Generate stacked bar plots for the two headline metrics. 
+ for metric in ["overall_accuracy", "bleu_score"]: + make_side_bar_plot(df, out_dir / f"{dataset}_{metric}.png", metric, dataset=dataset) + + # Print numerical results (and compute % improvement metrics) + grouped = df.groupby(["model", "solver"]).agg(["mean", "sem"]) + for type, closedbook, openbook in [ + ( + "Translation (BLEU)", + "average_non_retrieval_bleu_score", + "average_retrieval_bleu_score", + ), + ("Non-translation (%)", "baseline_accuracy", "retrieval_accuracy"), + ]: + print(f"Improvement Metrics for {type} on {dataset.capitalize()} dataset") + improvement_rows = [] + for idx, row in grouped.iterrows(): + openbook_score = row[openbook]["mean"] + closedbook_score = row[closedbook]["mean"] + rel_improvement_score = (openbook_score - closedbook_score) / (1 - closedbook_score) + improvement_rows.append( + { + "model": idx[0], + "solver": idx[1], + "closedbook": closedbook_score, + "openbook": openbook_score, + "improvement": rel_improvement_score, + } + ) + improvement_df = pd.DataFrame(improvement_rows) + print(improvement_df) + # print to stdout as csv + print(improvement_df.to_csv(index=False)) diff --git a/evals/elsuite/skill_acquisition/scripts/run_experiments.sh b/evals/elsuite/skill_acquisition/scripts/run_experiments.sh new file mode 100755 index 0000000000..aaf81e0745 --- /dev/null +++ b/evals/elsuite/skill_acquisition/scripts/run_experiments.sh @@ -0,0 +1,76 @@ +logdir=./logs +outputdir=./outputs + +timestamp=$(date +%Y%m%d_%H%M%S) +logpathbase=$logdir/$timestamp/ + +size=full +num_repeats=1 +eval_variants_zero_shot=("skill_acquisition.miskito.zero_shot.$size") + +# Check for --num_repeats argument +for arg in "$@" +do + if [[ $arg == --num_repeats=* ]]; then + num_repeats="${arg#*=}" + fi +done + + +echo Running experiments and logging to $logpathbase + +declare -a ZEROSHOT_SOLVERS=( + # Solvers for gpt-3.5-turbo + "generation/direct/gpt-3.5-turbo" + "skill_acquisition/cot/gpt-3.5-turbo" + + + # Solvers for gpt-4-turbo-preview + 
"generation/direct/gpt-4-turbo-preview" + "skill_acquisition/cot/gpt-4-turbo-preview" +) + +declare -a FEWSHOT_SOLVERS=( + "miskito_all/fewshot_direct/gpt-3.5-turbo" + "miskito_all/fewshot_direct/gpt-4-turbo-preview" +) + +if [ ! -d "$logpathbase/miskito" ]; then + mkdir -p "$logpathbase/miskito" +fi + + +# Run zero-shot experiments. +for eval_variant in "${eval_variants_zero_shot[@]}" +do + if [[ $eval_variant == *"miskito"* ]]; then + record_path="$logpathbase/miskito" + fi + + for solver in "${ZEROSHOT_SOLVERS[@]}" + do + for ((i=1;i<=num_repeats;i++)); do + echo "Running $solver, iteration $i" + oaieval $solver $eval_variant --record_path "$record_path/$solver-$i.log" + done + done +done + +# Run few-shot experiments. +# Miskito +for solver in "${FEWSHOT_SOLVERS[@]}" +do + if [[ $solver == *"miskito"* ]]; then + for ((i=1;i<=num_repeats;i++)); do + echo "Running $solver, iteration $i" + oaieval $solver skill_acquisition.miskito.few_shot.$size --record_path "$logpathbase/miskito/$solver-$i.log" + done + fi +done + +echo Done running experiments, all logs in $logpathbase + +echo Producing plots, outputs to $outputdir + +# Produce results +python make_plots.py --log-dir $logpathbase --out-dir $outputdir \ No newline at end of file diff --git a/evals/elsuite/skill_acquisition/solvers.py b/evals/elsuite/skill_acquisition/solvers.py new file mode 100644 index 0000000000..16bd72f653 --- /dev/null +++ b/evals/elsuite/skill_acquisition/solvers.py @@ -0,0 +1,22 @@ +from evals.solvers.openai_assistants_solver import OpenAIAssistantsSolver +from evals.solvers.solver import SolverResult +from evals.task_state import TaskState + + +class SkillAcquisitionAssistantsSolver(OpenAIAssistantsSolver): + def _solve( + self, + task_state: TaskState, + **kwargs, + ) -> SolverResult: + + if task_state.current_state is not None: + task_state.task_description = """ +You are about to take a Q&A test in a language called {target_language}. 
You should do your best to answer each question correctly. Your score will be your number of correct answers divided by the number of questions. + +This is an open book test, and you have been provided with relevant materials to help you answer the questions via your file browser. + +To provide your final answer, output [ANSWER X], where X is the answer. For example, if you answer is 'this is a bear', output [ANSWER this is a bear] +""".strip() + + return super()._solve(task_state, **kwargs) diff --git a/evals/elsuite/skill_acquisition/task_description.py b/evals/elsuite/skill_acquisition/task_description.py new file mode 100644 index 0000000000..5a356179d0 --- /dev/null +++ b/evals/elsuite/skill_acquisition/task_description.py @@ -0,0 +1,3 @@ +TASK_DESCRIPTION = """You are about to take a Q&A test in a language called {target_language}. You should do your best to answer each question correctly. Your score will be your number of correct answers divided by the number of questions. + +To provide your final answer, output [ANSWER X], where X is the answer. 
For example, if you answer is 'this is a bear', output [ANSWER this is a bear]""" diff --git a/evals/elsuite/skill_acquisition/test_skill_acquisition.py b/evals/elsuite/skill_acquisition/test_skill_acquisition.py new file mode 100644 index 0000000000..079fdcb435 --- /dev/null +++ b/evals/elsuite/skill_acquisition/test_skill_acquisition.py @@ -0,0 +1,118 @@ +import json +import os +from pathlib import Path + +from evals.elsuite.skill_acquisition.eval import SkillAcquisition +from evals.elsuite.skill_acquisition.utils import ( + answer_detected, + process_answer, + process_view_instruction, + view_instruction_detected, +) +from evals.registry import Registry + +registry = Registry() + +dummy_eval_spec = { + "eval_registry_path": Path("evals/registry"), + "completion_fns": [registry.make_completion_fn("gpt-4")], + "samples_jsonl": "skill_acquisition/miskito/miskito_test.jsonl", + "target_language": "miskito", + "n_samples": 5, + "knowledge_base_directory": "skill_acquisition/miskito/knowledge_base/", + "max_replies": 50, +} + + +def test_answer_detected(): + assert answer_detected("[ANSWER foo]") is True + assert answer_detected("[ANSWER: foo]") is True + assert answer_detected("ANSWER foo") is False + assert answer_detected("[ANSWER foo") is False + assert answer_detected("ANSWER foo]") is False + assert answer_detected("[ANSWER foo][ANSWER bar]") is True + + +def test_view_instruction_detected(): + SkillAcquisition(**dummy_eval_spec) + assert view_instruction_detected("[VIEW file1]") is True + assert view_instruction_detected("[VIEW: file1]") is True + assert view_instruction_detected("[VIEW file1 section1]") is True + assert view_instruction_detected("[VIEW: file1 section1]") is True + assert view_instruction_detected("VIEW file1") is False + assert view_instruction_detected("[VIEW file1") is False + assert view_instruction_detected("VIEW file1]") is False + assert view_instruction_detected("[VIEW file1][VIEW file2]") is True + assert 
view_instruction_detected("[VIEW: file1][VIEW: file2]") is True + + +def test_process_answer(): + SkillAcquisition(**dummy_eval_spec) + assert process_answer("[ANSWER foo]") == "foo" + assert process_answer("[ANSWER: foo]") == "foo" + assert process_answer("[ANSWER foo bar baz]") == "foo bar baz" + assert process_answer("[ANSWER: foo bar baz]") == "foo bar baz" + assert process_answer("[ANSWER foo][ANSWER bar]") == "bar" + assert process_answer("[ANSWER foo][ANSWER bar") == "foo" + + +def test_process_view_instruction(): + SkillAcquisition(**dummy_eval_spec) + assert process_view_instruction("[VIEW file1]") == ("file1", None) + assert process_view_instruction("[VIEW: file1]") == ("file1", None) + assert process_view_instruction("[VIEW file1 section1]") == ( + "file1", + "section1", + ) + assert process_view_instruction("[VIEW: file1 section1]") == ( + "file1", + "section1", + ) + assert process_view_instruction("[VIEW file1][VIEW file2]") == ( + "file2", + None, + ) + assert process_view_instruction("[VIEW: file1][VIEW: file2]") == ( + "file2", + None, + ) + assert process_view_instruction("[VIEW file1 section1][VIEW file2 section2]") == ( + "file2", + "section2", + ) + + +def test_process_view_instruction_spaces_and_quotes(): + assert process_view_instruction("[VIEW file1 sectionpart1 sectionpart2]") == ( + "file1", + "sectionpart1 sectionpart2", + ) + assert process_view_instruction("[VIEW file1 sectionpart1 'sectionpart2']") == ( + "file1", + "sectionpart1 'sectionpart2'", + ) + + +def test_view_content(): + skill_acquisition_eval = SkillAcquisition(**dummy_eval_spec) + + # Create a file to view first. 
+ filepath = skill_acquisition_eval.knowledge_base_directory / "test_file.jsonl" + with open(filepath, "w") as f: + f.write(json.dumps({"title": "foo", "content": "Test file contents."}) + "\n") + + content, sections_visible_to_model, sections_viewed = skill_acquisition_eval._view_content( + "test_file.jsonl" + ) + assert content == "Table of contents for test_file.jsonl: {'foo'}." + assert sections_visible_to_model == {"test_file.jsonl": {"foo"}} + assert sections_viewed == {"test_file.jsonl": {"Table of Contents"}} + + content, sections_visible_to_model, sections_viewed = skill_acquisition_eval._view_content( + "test_file.jsonl", "foo" + ) + assert content == "Test file contents." + assert sections_visible_to_model == {"test_file.jsonl": {"foo"}} + assert sections_viewed == {"test_file.jsonl": {"Table of Contents", "foo"}} + + os.remove(filepath) diff --git a/evals/elsuite/skill_acquisition/utils.py b/evals/elsuite/skill_acquisition/utils.py new file mode 100644 index 0000000000..ef0ef09349 --- /dev/null +++ b/evals/elsuite/skill_acquisition/utils.py @@ -0,0 +1,179 @@ +import json +import logging +import random +import re +from math import sqrt +from typing import Callable, Dict, List, Set, Union + +import numpy as np +from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu + +ANSWER_PATTERN = r"\[ANSWER:? .+?\]" +VIEW_PATTERN = r"\[VIEW:? .+?\]" +PROMPTS = { + "retrieval_instructions": ( + "This is an open-book test." + " If you are unsure about a question or unable to answer, you may read through a number of" + " knowledge sources that are made available to you in the form of text files." + " To open a file, output [VIEW X] where X is the name of the source you would like." + " For example, to view a file called help.jsonl, output [VIEW help.jsonl]." + " When you first open a file I will tell you its table of contents." 
+ " You may view a particular section by outputting [VIEW X Y] where X is the name of" + " the file and Y is the name of the section you would like to read." + " The section titles won't necessarily match your exact task." + " You should explore them for any useful information, such as key words or phrases," + " before giving your answer. You should always try to provide an answer," + " even if you are not confident. Failure to provide an answer is penalised" + " more strongly than incorrect answers." + "\nHere are the sources available to you: {list_of_files}." + ), + "intermediate_prompt": "You've already viewed the following files and sections: {sections}.\nYou can view another file or section by outputting [VIEW X] or [VIEW X Y], or you can answer the question by outputting [ANSWER X].", + "present_content": "You asked to view file {file}, section {section}. Here is the content: {content}", + "wrong_file": "You tried to view {file}, which does not exist in the knowledge base. Choose another file from {knowledge_base}.", + "wrong_section": "You tried to view section {section} in file {file}, which does not exist. The table of contents for that file contains: {table_of_contents}.", +} + +logger = logging.getLogger(__name__) + + +def answer_detected(output: str) -> bool: + return len(re.findall(ANSWER_PATTERN, output)) > 0 + + +def view_instruction_detected(output: str) -> bool: + return len(re.findall(VIEW_PATTERN, output)) > 0 + + +def process_answer(output: str) -> str: + """Extracts the answer from model output. + The answer looks like [ANSWER X], where X is the answer. + + Args: + output (str): model output + + Returns: + str: answer provided by the model + """ + maybe_multiple_answers = re.findall(ANSWER_PATTERN, output) + + # Sanity check – this should never happen. + assert len(maybe_multiple_answers) > 0, f"No answer detected in {output}." 
+ + if len(maybe_multiple_answers) > 1: + logger.debug( + f"Multiple answers detected, using only the final answer: {maybe_multiple_answers}" + ) + + final_answer_instruction = maybe_multiple_answers[-1] + final_answer = " ".join(final_answer_instruction.split(" ")[1:])[:-1] + + return final_answer + + +def process_view_instruction(output: str) -> Union[tuple[str, str], tuple[str, None]]: + """Extracts the target of a view instruction from model output. + The view instruction looks like [VIEW X Y], where X is a file name and Y is a section name. + This function extracts X and Y. + + Args: + output (str): model output + + Returns: + Union[tuple[str, str], tuple[str, None]]: tuple of file name and if applicable section name to view + """ + maybe_multiple_views = re.findall(VIEW_PATTERN, output) + + # Sanity check – this should never happen. + assert len(maybe_multiple_views) > 0, f"No view instruction detected in {output}." + + if len(maybe_multiple_views) > 1: + logger.debug( + f"Multiple view instructions detected, using only the final instruction: {maybe_multiple_views}" + ) + + final_view_instruction = maybe_multiple_views[-1][1:-1].split(" ")[1:] + file = final_view_instruction[0].strip() + + section = ( + None if len(final_view_instruction) == 1 else " ".join(final_view_instruction[1:]).strip() + ) + + return (file, section) + + +def _get_average_metric( + results: List[Dict[str, str]], metric_fn: Callable[List[Dict[str, str]], List[float]] +) -> float: + total_metric = sum(metric_fn(results)) + num_total = len(results) + if num_total == 0: + return float("nan") + else: + return total_metric / num_total + + +def get_bootstrap_accuracy_std(results: List[Dict[str, str]], num_samples: int = 1000) -> float: + results = [sample for sample in results if sample["question_type"] != "translation"] + vals = [result["correct"] for result in results] + return np.std([np.mean(random.sample(vals, len(vals) // 2)) for _ in range(1000)]) + + +def 
render_intermediate_prompt(sections_viewed: Dict[str, Set]) -> str: + return PROMPTS["intermediate_prompt"].format( + sections=json.dumps( + {k: list(v) for k, v in sections_viewed.items()}, indent=4 + ) # Cannot serialise sets directly. + ) + + +def get_question_type(question: str) -> str: + return "translation" if question.strip().startswith("Translate") else "non-translation" + + +def get_average_bleu_score(results: List[Dict[str, str]]) -> float: + results = [sample for sample in results if sample["question_type"] == "translation"] + return _get_average_metric( + results, + lambda samples: [ + get_bleu_score(sample["expected"][0], sample["parsed_output"]) for sample in samples + ], + ) + + +def get_bleu_score(expected: str, sampled: str) -> float: + punctuation = r"[^\w\s]" + + return sentence_bleu( + [re.sub(punctuation, "", expected).split()], + re.sub(punctuation, "", sampled).split(), + smoothing_function=SmoothingFunction().method1, + ) + + +def get_accuracy(results: List[Dict[str, str]]) -> float: + results = [sample for sample in results if sample["question_type"] != "translation"] + return _get_average_metric( + results, lambda samples: [int(sample["correct"]) for sample in samples] + ) + + +def get_average_retrieval_calls(results: List[Dict[str, str]]) -> float: + return _get_average_metric( + results, lambda samples: [sample["total_retrieval_calls"] for sample in samples] + ) + + +def get_average_invalid_retrieval_calls(results: List[Dict[str, str]]) -> float: + return _get_average_metric( + results, lambda samples: [sample["invalid_retrieval_calls"] for sample in samples] + ) + + +def get_average_retrieval_precision(results: List[Dict[str, str]]) -> float: + return _get_average_metric( + results, lambda samples: [sample["lesson_retrieval_calls"] for sample in samples] + ) + + +def get_std_of_difference(baseline_std: float, retrieval_std: float) -> float: + return sqrt(baseline_std**2 + retrieval_std**2) diff --git 
a/evals/elsuite/solver_tools_convo.py b/evals/elsuite/solver_tools_convo.py new file mode 100644 index 0000000000..8a13adf80b --- /dev/null +++ b/evals/elsuite/solver_tools_convo.py @@ -0,0 +1,240 @@ +import copy +import logging +import re +from dataclasses import dataclass +from typing import Any, Optional + +from evals.elsuite.bugged_tools.tools import Tool, ToolTaskState +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import Message, TaskState + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolCall: + tool_name: str + input: str + output: Any + + +@dataclass +class ParsedSolverResult: + tool_calls: list[ToolCall] + final_answer: Optional[str] + + +@dataclass +class RunnerResult: + final_task_state: ToolTaskState + final_solver_result: SolverResult + metrics: dict + + +class Runner: + def __init__( + self, + solver: Solver, + sample: Any, + name_to_tool: dict, + max_turns: int, + default_task_description: str, + default_reminder_message: str, + ): + self.solver = solver + self.sample = sample + self.name_to_tool = name_to_tool + self.max_turns = max_turns + self.default_task_description = default_task_description + self.default_reminder_message = default_reminder_message + + def run(self) -> RunnerResult: + # Prepare initial task state + tools = self.name_to_tool.values() + tool_names_and_descriptions = self._get_tool_names_and_descriptions(tools) + task_description = self.default_task_description.format( + tool_names_and_descriptions=tool_names_and_descriptions + ) + task_message = self.sample["task"] + messages = [ + Message(role="user", content=task_message), + ] + task_state = TaskState( + task_description=task_description, + messages=messages, + current_state=None, + ) + + # Loops until solver completes task or hits turn limit + turn = 0 + final_answer = None + while turn < self.max_turns: + # Get result from solver + solver_result = self.solver(task_state) + parsed_solver_result = 
self._parse_solver_result(solver_result) + final_answer = parsed_solver_result.final_answer + + # If solver failed to call tool or give final answer, prompt them to try again + if parsed_solver_result.tool_calls == [] and final_answer is None: + content = self.default_reminder_message + task_state = self._add_eval_message(task_state, solver_result, content=content) + turn += 1 + continue + + if final_answer is not None: + return self._finish_run(task_state, solver_result, final_answer, turn) + + # Run tools. If solver gave tool incorrect input, prompt them to try again. + assert parsed_solver_result.tool_calls != [] + tool_outputs = [self._run_tool_call(i) for i in parsed_solver_result.tool_calls] + if any([i is None for i in tool_outputs]): + content = self.default_reminder_message + task_state = self._add_eval_message(task_state, solver_result, content=content) + turn += 1 + continue + + # Add user message containing tool outputs + task_state = self._add_tool_outputs(task_state, solver_result, tool_outputs) + turn += 1 + + return self._finish_run(task_state, solver_result, None, turn) + + def _get_tool_names_and_descriptions(self, tools: list[Tool]): + """ + Given sequence of tools, creates a string of each tools name + and description, each tool's info separated by a newline + """ + s = "" + for tool in tools: + s += f"{tool._name}: {tool._desc}\n" + return s + + def _parse_solver_result(self, solver_result: SolverResult) -> ParsedSolverResult: + output = solver_result.output + tool_calls = self._parse_tool_calls(output) + final_answer = self._parse_final_answer(output) + return ParsedSolverResult(tool_calls=tool_calls, final_answer=final_answer) + + def _parse_tool_calls(self, output: str) -> Optional[list[ToolCall]]: + tool_message_matches = self._find_tool_messages(output) + if tool_message_matches == []: + return [] + + tool_calls = [] + for tool_name, tool_message in tool_message_matches: + # Log warning if solver calls a tool that doesn't exist + try: + 
self.name_to_tool[tool_name] + except KeyError: + logger.warn(f"Solver tried to call '{tool_name}' tool which doesn't exist!") + continue + + tool_call = ToolCall(tool_name=tool_name, input=tool_message, output=None) + tool_calls.append(tool_call) + return tool_calls + + def _find_tool_messages(self, text: str) -> list[tuple[str, str]]: + """ + Finds all tool calls, which are formatted [NAME: INPUT], + where NAME != "Answer" and NAME != "Bugged" + """ + pattern = r"\(@(?!Answer|Bugged)(\w+): (.+?)\)" + matches = re.findall(pattern, text, re.DOTALL) + return matches + + def _parse_final_answer(self, output: str) -> Optional[str]: + """ + If a final answer exists of form [Answer: OUTPUT], returns the output, + otherwise returns None + """ + match = re.search(r"\(@Answer: (.*?)\)", output, re.DOTALL) + return match.group(1) if match else None + + def _run_tool_call(self, tool_call: ToolCall) -> ToolCall: + # Prepare task state + tool_name = tool_call.tool_name + tool = self.name_to_tool[tool_name] + tool_input = tool_call.input + tool_desc = self.name_to_tool[tool_name]._desc + + # Remove quotes if solver wrapped input + if tool_input.startswith(("'", '"')) and tool_input.endswith(("'", '"')): + tool_input = tool_input[1:-1] + + task_description = ( + f"Your name is {tool_name}. 
A description of your purpose is shown below:\n{tool_desc}" + ) + messages = [Message(role="user", content=tool_input)] + task_state = ToolTaskState( + task_description=task_description, messages=messages, current_state=None + ) + try: + out = tool(task_state) + except (TypeError, ValueError, IndexError): + out = None + + if out is None: + return None + + tool_call.output = out.output + return tool_call + + def _add_eval_message( + self, + task_state: TaskState, + solver_output: SolverResult, + content: str, + ) -> TaskState: + messages = copy.deepcopy(task_state.messages) + messages.append(Message(role="assistant", content=solver_output.output)) + # NOTE: we assume that the order of tool_outputs is the same as the order of tool_calls + + messages.append(Message(role="user", content=content)) + new_task_state = TaskState( + task_description=task_state.task_description, + messages=messages, + current_state=None, + ) + return new_task_state + + def _add_tool_outputs( + self, + task_state: TaskState, + solver_output: SolverResult, + tool_outputs: list[ToolCall], + ) -> TaskState: + content = "" + for tool_output in tool_outputs: + name = tool_output.tool_name + input = tool_output.input + output = tool_output.output + content += f"{name} output on input {input}: {output}\n" + + return self._add_eval_message(task_state, solver_output, content) + + def _finish_run( + self, + final_task_state: TaskState, + solver_result: SolverResult, + final_answer: Optional[str], + turn: int, + ) -> RunnerResult: + expected_answer = self.sample["answer"] + is_correct = False + if final_answer is not None: + final_answer = final_answer.lower().strip() + # Remove quotes if solver wrapped input + if final_answer.startswith(("'", '"')) and final_answer.endswith(("'", '"')): + final_answer = final_answer[1:-1] + is_correct = final_answer == expected_answer.lower().strip() + + metrics = { + "is_correct": is_correct, + "num_turns": turn + 1, # zero-indexed, + } + + return RunnerResult( + 
final_task_state, + solver_result, + metrics, + ) diff --git a/evals/elsuite/track_the_stat/README.md b/evals/elsuite/track_the_stat/README.md new file mode 100644 index 0000000000..20c1580b2f --- /dev/null +++ b/evals/elsuite/track_the_stat/README.md @@ -0,0 +1,134 @@ +# Track the Stat + +This eval measures how well models can implicitly keep track of task state, by +asking models to compute the rolling median or the rolling mode over a sequence +of integers. + +## Usage + +Run with: + +```bash +oaieval track_the_stat +``` + +We have found that `generation/direct/gpt-4-0125-preview` works well on this +eval. For more examples of tested solvers, see +[`./scripts/run_experiments.sh`](./scripts/run_experiments.sh). + +## Evaluation Process + +The evaluation process is as follows for a given sample from our dataset: + +1. The `TASK_DESCRIPTION` prompt is shown to the solver. +2. The sample contains an integer to use as a seed for a random number + generator. +3. The random number generator generates 300 random integers between 0 and 100, + with replacement. +4. The integers are shown one by one to the solver. +5. At each turn (i.e., after each integer is shown), the solver needs to respond + with the current rolling median or the current rolling mode of the integers + seen so far. +6. The solver's response is parsed and compared to the correct rolling median or + rolling mode. +7. If the solver's response is incorrect or a violation is raised (answered in + the incorrect format), the evaluation stops and we measure how many turns the + solver lasted for. If the solver's response is correct, we move on to the + next integer. + +## Prompts + +We refer readers to the [`./prompts/`](./prompts/) folder for the +`TASK_DESCRIPTION` used in the eval. 
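The reference answers described in the evaluation process above can be sketched as follows. This is a minimal illustration, not the eval's own implementation (which lives in `utils.py`): the helper names `rolling_median` and `rolling_mode` are hypothetical, and the standard-library `statistics` and `collections` modules are swapped in for brevity. The input sequence is the one used in the prompt examples, and mode ties are broken in favour of the largest number, as the mode prompt specifies.

```python
from collections import Counter
from statistics import median


def rolling_median(numbers: list[int]) -> list[float]:
    """Median of each prefix; even-length prefixes use the mean of the two middle values."""
    return [float(median(numbers[: i + 1])) for i in range(len(numbers))]


def rolling_mode(numbers: list[int]) -> list[int]:
    """Mode of each prefix; ties are broken in favour of the largest number."""
    counts: Counter = Counter()
    modes = []
    for n in numbers:
        counts[n] += 1
        top = max(counts.values())
        # Among the numbers with the highest count, keep the largest one
        modes.append(max(k for k, c in counts.items() if c == top))
    return modes


sequence = [1, 2, 1, 3, 3, 0]  # sequence from the prompt examples
print(rolling_median(sequence))
print(rolling_mode(sequence))
```

A solver is expected to reproduce each of these per-turn values without being shown the sorted list or the counts, which is what makes the state tracking implicit.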
+ +## Metrics + +Below are the metrics returned by the eval: + + +| **Metric** | **Notes** | +|------------------- |-------------------------------------------------------------------------------------------------------------------------------------------- | +| avg_max_length | The maximum sequence length the model can handle before failing, averaged across the samples. Higher is better. Best possible is 300. | +| stddev_max_length | The standard deviation of the above. | +| median_max_length | The median of the maximum sequence length the model can handle before failing, across the samples. Higher is better. Best possible is 300. | +| max_max_length | The maximum sequence length the model handled before failing across all samples. | +| min_max_length | The minimum sequence length the model handled before failing across all samples. | +| violation_rate | How often the model responds in an invalid format, i.e. not using the `[<task>: <answer>]` format. | + + +## Variants + +The eval has two variants: median and mode. In the median variant, the solver +needs to track the rolling median. In the mode variant, the solver needs to +track the rolling mode. + +```bash +oaieval <solver> track_the_stat.<task> +``` + +## Custom Solvers + +We implement three custom solvers for this eval in [./solvers.py](./solvers.py): + +1. `ExplicitStateSolver`: A nested solver that injects an explicit + representation of the task state after each number is seen. For example, for + the median task we inject the sorted list of numbers seen so far. For the + mode task, we inject a dictionary that maps each number seen so far to its + count. We view this solver as a baseline for the task, providing the + performance of the models on _explicit_ state tracking, rather than the + default _implicit_ state tracking. +2. `RandomBaselineSolver`: A solver that randomly chooses a number from the + numbers seen so far as the rolling median or mode.
In case of even-length + lists in the median variant, it chooses two random numbers and returns their + arithmetic mean. We view this baseline as equivalent to randomly guessing. +3. `TrackTheStatHuman`: A helper solver class that wraps the `HumanCliSolver` + class such that users do not have to wrap their answer in the + `[median: <answer>]` or `[mode: <answer>]` format and can instead just + directly type the number. + +## Token Usage Estimates + +Below are token usage estimates for a given run (one run = all samples) of the +eval. + +For the mode task: + +| Model (state tracking) | Input | Output | Total | +| ----------------------------- | --------- | --------- | ---------- | +| gpt-3.5-turbo-0125 (implicit) | 670,000 | 10,000 | 680,000 | +| gpt-3.5-turbo-0125 (explicit) | 2,710,000 | 30,000 | 2,740,000 | +| gpt-4-base (implicit) | 9,030,000 | 2,110,000 | 11,150,000 | +| gpt-4-base (explicit) | 3,720,000 | 960,000 | 4,680,000 | +| gpt-4-0125-preview (implicit) | 3,050,000 | 30,000 | 3,080,000 | +| gpt-4-0125-preview (explicit) | 8,580,000 | 50,000 | 8,630,000 | + +For the median task: + +| Model (state tracking) | Input | Output | Total | +| ----------------------------- | --------- | ------- | --------- | +| gpt-3.5-turbo-0125 (implicit) | 430,000 | 10,000 | 440,000 | +| gpt-3.5-turbo-0125 (explicit) | 880,000 | 10,000 | 890,000 | +| gpt-4-base (implicit) | 2,900,000 | 760,000 | 3,660,000 | +| gpt-4-base (explicit) | 3,250,000 | 810,000 | 4,060,000 | +| gpt-4-0125-preview (implicit) | 690,000 | 10,000 | 700,000 | +| gpt-4-0125-preview (explicit) | 1,430,000 | 20,000 | 1,450,000 | + +## Future Modifications + +- Identify new variants of the task beyond median or mode, where the explicit + state is either impossible to represent or not useful for the task. This would + allow us to more comfortably measure the implicit state tracking, even on CoT + solvers. +- Identify more realistic and/or complex tasks. +- Introduce distractors.
+ +## Version History + +- v0: Initial version released + +## Contribution Statement + +Eval design, implementation, and results evaluation were primarily conducted by +Giulio Starace, under the guidance of (alphabetically by last-name) Steven +Adler, Andrei Alexandru, James Aung, and Chan Jun Shern who provided research +input, report revisions, and project management support. diff --git a/evals/elsuite/track_the_stat/eval.py b/evals/elsuite/track_the_stat/eval.py new file mode 100644 index 0000000000..d1ca65d719 --- /dev/null +++ b/evals/elsuite/track_the_stat/eval.py @@ -0,0 +1,96 @@ +import logging +import random +from typing import Any, Optional + +import numpy as np + +from evals.elsuite.track_the_stat import prompts, utils +from evals.eval import SolverEval +from evals.record import RecorderBase, record_metrics +from evals.solvers.solver import Solver +from evals.task_state import Message, TaskState + +logging.getLogger("httpx").setLevel(logging.WARNING) +logger = logging.getLogger(__name__) + + +class TrackTheStat(SolverEval): + def __init__(self, task: str, n_samples: Optional[int] = 250, *args, **kwargs): + super().__init__(*args, **kwargs) + assert task in [ + "median", + "mode", + ], f"task must be either 'median' or 'mode', but got {task}" + self.task = task + # warn, color in yellow + logger.warning( + utils.yellow_string( + "By nature of what is being evaluated, this eval assumes that the " + "solver cannot make use of external scratchpads or similar solutions " + "to explicitly write down the task state at every step. Using solvers " + "that allow for this functionality will likely produce invalid results." 
+ ) + ) + self.task_desc = prompts.TASK_DESCRIPTION.format( + task=task, + task_further_details=prompts.task_to_further_details[task], + task_example=prompts.task_to_example[task], + ) + self.task_fn = utils.task_to_fn[task] + self.n_samples = n_samples + self.rng = random.Random(self.seed) + + def eval_sample(self, solver: Solver, sample: Any, rng: random.Random) -> None: + capped_inf_list = np.random.default_rng(sample["seed"]).integers(0, 100, size=300) + metrics = self._eval_sample(solver, capped_inf_list) + + record_metrics(**metrics) + + def _eval_sample(self, solver: Solver, capped_inf_list: list[int]) -> dict: + violation = False + task_state = TaskState(task_description=self.task_desc, messages=[]) + for i, num in enumerate(capped_inf_list): + curr_list = capped_inf_list[: i + 1] + task_state.messages.append(Message(role="user", content=str(num))) + task_state.current_state = utils.compute_state(curr_list, self.task) + solver_output = solver(task_state).output + solver_response = utils.parse_solver_output(solver_output, self.task) + if solver_response is None: + violation = True + break + if round(solver_response, 1) != round(self.task_fn(curr_list), 1): + break + task_state.messages.append(Message(role="assistant", content=solver_output)) + + return { + "max_length": len(curr_list) - 1, + "violation": violation, + } + + def run(self, recorder: RecorderBase): + samples = self._get_samples() + self.eval_all_samples(recorder, samples) + logged_metrics: list[dict] = recorder.get_metrics() + + agg_metrics = self._compute_agg_metrics(logged_metrics) + return agg_metrics + + def _compute_agg_metrics(self, logged_metrics: list[dict]) -> dict: + max_lengths = np.array([x["max_length"] for x in logged_metrics]) + + agg_metrics = { + "avg_max_length": np.mean(max_lengths), + "stddev_max_length": np.std(max_lengths), + "median_max_length": np.median(max_lengths), + "max_max_length": np.max(max_lengths), + "min_max_length": np.min(max_lengths), + "violation_rate": 
np.mean([x["violation"] for x in logged_metrics]), + } + # necessary for serialization, json doesn't like np floats + agg_metrics = {k: float(v) for k, v in agg_metrics.items()} + return agg_metrics + + def _get_samples(self) -> list[dict]: + samples = [{"seed": num} for num in range(self.n_samples)] + samples = self.rng.sample(samples, self.n_samples) # shuffling + return samples diff --git a/evals/elsuite/track_the_stat/prompts/__init__.py b/evals/elsuite/track_the_stat/prompts/__init__.py new file mode 100644 index 0000000000..31d1b2de39 --- /dev/null +++ b/evals/elsuite/track_the_stat/prompts/__init__.py @@ -0,0 +1,27 @@ +from . import mode +from . import median + + +TASK_DESCRIPTION = """\ +In the following messages, you will be given a sequence of numbers.\ + At each turn, you will be shown a number as input, and you should respond with the\ + {task} of all the input numbers shown to you so far. + +{task_further_details} + +Here is an example of what this may look like. +{task_example} + +Format your response as [{task}: ] (square brackets included), as shown in\ +the transcript above. The task will begin now. 
+""" + +task_to_example = { + "median": median.MEDIAN_EXAMPLE, + "mode": mode.MODE_EXAMPLE, +} + +task_to_further_details = { + "median": median.MEDIAN_FURTHER_DETAILS, + "mode": mode.MODE_FURTHER_DETAILS, +} diff --git a/evals/elsuite/track_the_stat/prompts/median.py b/evals/elsuite/track_the_stat/prompts/median.py new file mode 100644 index 0000000000..aae3c0ecc8 --- /dev/null +++ b/evals/elsuite/track_the_stat/prompts/median.py @@ -0,0 +1,33 @@ +MEDIAN_EXAMPLE = """\ +```example +input: 1 +ideal_response: [median: 1]\ + # your response; 1 is the only number shown so far +--- +input: 2 +ideal_response: [median: 1.5]\ + # even number of numbers, so median = mean(1,2) = 1.5 +--- +input: 1 +ideal_response: [median: 1]\ + # 1 is now the middle number when sorting the numbers +--- +input: 3 +ideal_response: [median: 1.5]\ + # middle numbers are now 1 and 2, so once again median = mean(1,2) = 1.5 +--- +input: 3 +ideal_response: [median: 2]\ +# the sorted list is [1 1 2 3 3]; odd length, so median is the middle number, 2 +--- +input: 0 +ideal_response: [median: 1.5]\ +# the sorted list is [0 1 1 2 3 3]; even length, so median is mean(1,2) = 1.5 +```\ +""" + + +MEDIAN_FURTHER_DETAILS = """\ +NOTE: In case of lists containing an even number of elements, you should respond with the\ + arithmetic mean of the middle two numbers of the sorted list.\ +""" diff --git a/evals/elsuite/track_the_stat/prompts/mode.py b/evals/elsuite/track_the_stat/prompts/mode.py new file mode 100644 index 0000000000..5756e7e55c --- /dev/null +++ b/evals/elsuite/track_the_stat/prompts/mode.py @@ -0,0 +1,29 @@ +MODE_EXAMPLE = """\ +```example +input: 1 +ideal_response: [mode: 1]\ + # your response; 1 is the only number shown so far +--- +input: 2 +ideal_response: [mode: 2]\ + # 1 and 2 are tied modes (both appeared once), 2 > 1 +--- +input: 1 +ideal_response: [mode: 1]\ + # 1 now has appeared more than any other number +--- +input: 3 +ideal_response: [mode: 1] +--- +input: 3 +ideal_response: [mode: 
3]\ + # 3 is tied with 1 in terms of appearances, 3 > 1 +--- +input: 0 +ideal_response: [mode: 3] +```\ +""" + +MODE_FURTHER_DETAILS = """\ +NOTE: In case of ties, you should respond with the largest number that is part of the tie.\ +""" diff --git a/evals/elsuite/track_the_stat/scripts/make_plots.py b/evals/elsuite/track_the_stat/scripts/make_plots.py new file mode 100644 index 0000000000..b40e4a3586 --- /dev/null +++ b/evals/elsuite/track_the_stat/scripts/make_plots.py @@ -0,0 +1,296 @@ +from pathlib import Path +import argparse +import json + +from tqdm.auto import tqdm +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +from evals.utils import log_utils + + +def zero_if_none(input_num): + if input_num is None: + return 0 + else: + return input_num + + +MODELS = [ + "gpt-4-0125-preview", + "gpt-4-base", + "gpt-3.5-turbo-0125", + "gemini-pro-1.0", + "mixtral-8x7b-instruct", + "llama-2-70b-chat", + "random_baseline", + "human_baseline", +] +# separate list for OAI models for token counting, not supported in others. +OAI_MODELS = [ + "gpt-4-0125-preview", + "gpt-3.5-turbo-0125", + "gpt-4-base", +] + +STAT_TO_LABEL = { + "avg_max_length": "Average maximum sequence length achieved [no. 
of turns]", + "violation_rate": "Violation rate", +} + + +def make_results_dict(log_dir: Path) -> dict: + results_dict = prepare_results_dict() + results_dict = fill_results_dict(results_dict, log_dir) + return results_dict + + +def get_model(spec): + # this is hilariously ugly but it works for now (sorry) + if "gpt-4-turbo-preview" in spec["completion_fns"][0]: + return "gpt-4-0125-preview" + elif "gpt-3.5-turbo" in spec["completion_fns"][0]: + return "gpt-3.5-turbo-0125" + elif "gpt-4-base" in spec["completion_fns"][0]: + return "gpt-4-base" + elif "gemini-pro" in spec["completion_fns"][0]: + return "gemini-pro-1.0" + elif "mixtral-8x7b-instruct" in spec["completion_fns"][0]: + return "mixtral-8x7b-instruct" + elif "llama-2-70b-chat" in spec["completion_fns"][0]: + return "llama-2-70b-chat" + elif "random_baseline" in spec["completion_fns"][0]: + return "random_baseline" + elif "human" in spec["completion_fns"][0]: + return "human_baseline" + + +def get_state_tracking(spec): + if "explicit" in spec["completion_fns"][0]: + return "explicit" + else: + return "implicit" + + +def fill_results_dict(results_dict, log_dir): + print("Parsing logs...") + final_results = log_utils.get_final_results_from_dir(log_dir) + specs = log_utils.get_specs_from_dir(log_dir) + files = list(final_results.keys()) + + for file in tqdm(files): + final_result = final_results[file] + spec = specs[file] + task = spec["split"] + model = get_model(spec) + state_tracking = get_state_tracking(spec) + for stat in results_dict: + results_dict[stat][task][model][state_tracking]["raw"].append( + final_result[stat] + ) + # compute means/std_errs + for file in tqdm(files): + spec = specs[file] + task = spec["split"] + model = get_model(spec) + state_tracking = get_state_tracking(spec) + for stat in results_dict: + data_points = results_dict[stat][task][model][state_tracking]["raw"] + results_dict[stat][task][model][state_tracking]["mean"] = np.mean( + data_points + ) + 
results_dict[stat][task][model][state_tracking]["std_err"] = np.std( + data_points + ) / np.sqrt(len(data_points) if len(data_points) > 1 else 1) + return results_dict + + +def prepare_results_dict(): + results_dict = { + stat: { + task: { + model: { + state_tracking: {"raw": []} + for state_tracking in ["implicit", "explicit"] + } + for model in MODELS + } + for task in ["mode", "median"] + } + for stat in ["avg_max_length", "violation_rate"] + } + return results_dict + + +def make_bar_plot(results_dict: dict, task: str, stat: str, save_path: Path): + sns.set_context("paper") + sns.set_style("whitegrid") + + data = results_dict[stat][task] + + # the random baseline and human baseline aren't plotted as bars + models = MODELS[:-2] + + state_tracking_kinds = ["explicit", "implicit"] + + means = [ + [data[model][cat]["mean"] for cat in state_tracking_kinds] for model in models + ] + std_errs = [ + [data[model][cat]["std_err"] for cat in state_tracking_kinds] + for model in models + ] + cmap = plt.get_cmap("Paired") + colors = np.array([cmap(i) for i in range(len(state_tracking_kinds))]) + + # Plotting + x = np.arange(len(models)) # the label locations + + width = 0.4 + + fig, ax = plt.subplots(1, 1, figsize=(8, 6), dpi=300) + + explicit_bars = ax.barh( + x + width / 2, + [mean[0] for mean in means], + width, + xerr=[err[0] for err in std_errs], + label="Explicitly tracked state baseline", + color=colors[0], + ) + implicit_bars = ax.barh( + x - width / 2, + [mean[1] for mean in means], + width, + xerr=[err[1] for err in std_errs], + label="Implicitly tracked state", + color=colors[1], + ) + + ax.set_xlabel(STAT_TO_LABEL[stat]) + # maximum x + xerr value times 1.2 + x_max = ( + max([m for mean in means for m in mean]) + + max([e for err in std_errs for e in err]) + ) * 1.2 + ax.set_xlim([0, x_max]) + ax.set_yticks(x) + ax.set_yticklabels(models) + + ax.bar_label(implicit_bars, padding=3, fmt="%.2f") + ax.bar_label(explicit_bars, padding=3, fmt="%.2f") + + # plot random 
and human baselines + random_baseline = data["random_baseline"]["implicit"]["mean"] + random_err = data["random_baseline"]["implicit"]["std_err"] + ax.axvline(random_baseline, color="red", linestyle="--", label="Random baseline") + ax.axvspan( + random_baseline - random_err, + random_baseline + random_err, + color="red", + alpha=0.05, + ) + + human_baseline = data["human_baseline"]["implicit"]["mean"] + human_err = data["human_baseline"]["implicit"]["std_err"] + ax.axvline( + human_baseline, + color="#366a9d", + linestyle=":", + label="Human baseline (implicit)", + ) + + ax.axvspan( + human_baseline - human_err, + human_baseline + human_err, + color="#366a9d", + alpha=0.05, + ) + + # get rid of horizontal grid lines + ax.grid(axis="y", which="both") + + ax.legend() + + fig.tight_layout() + + plt.savefig(save_path, bbox_inches="tight", dpi=300) + + +def count_tokens(log_dir) -> dict[str, dict[str, dict[str, int]]]: + """ + model -> task -> input, output, total tokens + """ + token_counts = { + model: { + task: { + state_tracking: {kind: 0 for kind in ["input", "output", "total"]} + for state_tracking in ["implicit", "explicit"] + } + for task in ["mode", "median"] + } + for model in OAI_MODELS + } + globbed_logs = list(log_dir.glob("*.log")) + already_examined = set() + for log in tqdm(globbed_logs, total=len(globbed_logs), desc="Counting tokens"): + spec = log_utils.extract_spec(log) + task = spec["split"] + model = get_model(spec) + state_tracking = get_state_tracking(spec) + + if model not in OAI_MODELS: + continue + + # dont care about repeats, this is a rough estimate anyway + if (model, task, state_tracking) in already_examined: + continue + already_examined.add((model, task, state_tracking)) + + samplings = log_utils.extract_individual_results(log, "sampling") + for sampling in samplings: + usage = sampling["usage"] + token_counts[model][task][state_tracking]["input"] += zero_if_none( + usage["prompt_tokens"] + ) + 
token_counts[model][task][state_tracking]["output"] += zero_if_none( + usage["completion_tokens"] + ) + token_counts[model][task][state_tracking]["total"] += zero_if_none( + usage["total_tokens"] + ) + return token_counts + + +def main(args: argparse.Namespace): + log_dir = Path(args.log_dir) + save_dir = Path(args.save_dir) + save_dir.mkdir(exist_ok=True, parents=True) + + results_dict = make_results_dict(log_dir) + + for stat in tqdm(results_dict.keys(), desc=f"Plotting..."): + for task in tqdm(["mode", "median"], desc=f"Plotting {stat}"): + save_path = save_dir / f"{task}_{stat}.png" + make_bar_plot(results_dict, task, stat, save_path) + save_path = save_dir / f"{stat}.json" + with open(save_path, "w") as f: + json.dump(results_dict[stat], f, indent=2) + + token_counts = count_tokens(log_dir) + save_path = save_dir / "token_counts.json" + with open(save_path, "w") as f: + json.dump(token_counts, f, indent=2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--log_dir", type=str, required=True, help="Where the logs are stored" + ) + parser.add_argument( + "--save_dir", type=str, required=True, help="Where to save the plots" + ) + args = parser.parse_args() + main(args) diff --git a/evals/elsuite/track_the_stat/scripts/run_experiments.sh b/evals/elsuite/track_the_stat/scripts/run_experiments.sh new file mode 100644 index 0000000000..8307866418 --- /dev/null +++ b/evals/elsuite/track_the_stat/scripts/run_experiments.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +usage() { + echo "Usage: $0 -l logdir" + echo " -l logdir Specify the directory for log files" + exit 1 +} + +# Check if no arguments were provided +if [ $# -eq 0 ]; then + usage + exit 1 +fi + +# Parse command-line options +while getopts 's:l:' flag; do + case "${flag}" in + l) logdir=${OPTARG} ;; + *) usage ;; + esac +done + +# Check if mandatory arguments were provided +if [ -z "$logdir" ]; then + usage + exit 1 +fi + +NUM_REPEATS=3 + +export EVALS_THREADS=10 +export 
EVALS_THREADS_TIMEOUT=5 + +declare -a SOLVERS=( + # 4-turbo-preview + "generation/direct/gpt-4-turbo-preview" + "track_the_stat/explicit_state/gpt-4-turbo-preview" + # 3.5-turbo + "generation/direct/gpt-3.5-turbo" + "track_the_stat/explicit_state/gpt-3.5-turbo" + # 4-base + "generation/hhh/gpt-4-base" + "track_the_stat/explicit_state/hhh/gpt-4-base" + # gemini pro + "generation/direct/gemini-pro" + "track_the_stat/explicit_state/gemini-pro" + # mixtral-8x7b-instruct + "generation/direct/mixtral-8x7b-instruct" + "track_the_stat/explicit_state/mixtral-8x7b-instruct" + # llama chat 70b + "generation/direct/llama-2-70b-chat" + "track_the_stat/explicit_state/llama-2-70b-chat" + # random baseline + "track_the_stat/random_baseline" +) +declare -a TASKS=( + "mode" + "median" +) + +# Check if GEMINI_API_KEY is set +if [ -z "$GEMINI_API_KEY" ]; then + echo "Enter your Gemini API Key:" + read -s GEMINI_API_KEY + export GEMINI_API_KEY +fi + +# Check if TOGETHER_API_KEY is set +if [ -z "$TOGETHER_API_KEY" ]; then + echo "Enter your Together API Key:" + read -s TOGETHER_API_KEY + export TOGETHER_API_KEY +fi + +start_time=$SECONDS +for ((i = 1; i <= NUM_REPEATS; i++)); do + for task in "${TASKS[@]}"; do + for solver in "${SOLVERS[@]}"; do + if [[ $solver == *"gemini"* ]]; then + export EVALS_SEQUENTIAL=1 + else + export EVALS_SEQUENTIAL=0 + fi + solver_dotted=${solver//\//.} + record_path="${logdir}/${solver_dotted}_${task}_${i}" + echo "Running $solver on $task (repeat $i)" + oaieval $solver "track_the_stat.${task}" \ + --record_path "$record_path.log" --seed $i + done + done +done +echo "Total time: $((SECONDS - start_time)) seconds" diff --git a/evals/elsuite/track_the_stat/solvers.py b/evals/elsuite/track_the_stat/solvers.py new file mode 100644 index 0000000000..65721002cc --- /dev/null +++ b/evals/elsuite/track_the_stat/solvers.py @@ -0,0 +1,98 @@ +import random +from typing import Any + +from evals.elsuite.track_the_stat import utils +from evals.solvers.solver import 
NestedSolver, Solver, SolverResult, SolverSpec +from evals.task_state import Message, TaskState + + +class ExplicitStateSolver(NestedSolver): + def __init__( + self, + underlying_solver: SolverSpec, + state_role: str = "assistant", + *args, + **kwargs, + ): + super().__init__(underlying_solver=underlying_solver, *args, **kwargs) + self.state_role = state_role + + @property + def underlying_solver(self) -> Solver: + return self.get_solver("underlying_solver") + + def _render_state(self, current_state: dict) -> str: + rendered_state_string = f"{current_state['state_label']}\n{current_state['state_data']}" + return rendered_state_string + + def _build_message(self, task_state: TaskState) -> str: + message_string = "The current state, useful for solving the task\n" + self._render_state( + task_state.current_state + ) + return Message(role=self.state_role, content=message_string) + + def _solve(self, task_state: TaskState) -> SolverResult: + precomputed_state_message = self._build_message(task_state) + task_state.messages.append(precomputed_state_message) + + solver_result = self.underlying_solver(task_state=task_state) + return solver_result + + +class RandomBaselineSolver(Solver): + def __init__(self, registry: Any = None, *args, **kwargs): + super().__init__() + + def _solve(self, task_state: TaskState) -> SolverResult: + task = task_state.current_state["task_name"] + random_output = self._task_solve(task, task_state) + solver_result = SolverResult(output=f"[{task}: {random_output}]") + return solver_result + + def _task_solve(self, task: str, task_state: TaskState) -> str: + if task == "mode": + return self._mode_solve(task_state) + elif task == "median": + return self._median_solve(task_state) + + def _mode_solve(self, task_state: TaskState) -> str: + """ + Picks a random number from the numbers seen so far + """ + numbers = list(task_state.current_state["state_data"].keys()) + random_mode = random.choice(numbers) + return str(random_mode) + + def 
_median_solve(self, task_state: TaskState) -> str: + """ + Picks a random number from the numbers seen so far + (in case of even number of numbers, picks the average of two random numbers) + """ + numbers = task_state.current_state["state_data"] + if len(numbers) % 2 == 0: + random_1, random_2 = random.choices(numbers, k=2) + random_median = (random_1 + random_2) / 2 + else: + random_median = random.choice(numbers) + return str(round(random_median, 1)) + + +class TrackTheStatHuman(NestedSolver): + def __init__(self, human_cli_solver: SolverSpec, *args, **kwargs): + super().__init__(human_cli_solver=human_cli_solver, *args, **kwargs) + + @property + def human_cli_solver(self) -> Solver: + return self.get_solver("human_cli_solver") + + def _solve(self, task_state: TaskState) -> SolverResult: + human_result = self.human_cli_solver(task_state=task_state) + task = task_state.current_state["task_name"] + # wrap the result in [: ] if not already wrapped + output = utils.parse_solver_output(human_result.output, task) + if output is None: # there is a violation -- output is not wrapped + return SolverResult( + output=f"[{task}: {human_result.output}]", + ) + else: # no violation -- output is already wrapped + return human_result diff --git a/evals/elsuite/track_the_stat/utils.py b/evals/elsuite/track_the_stat/utils.py new file mode 100644 index 0000000000..55467c5100 --- /dev/null +++ b/evals/elsuite/track_the_stat/utils.py @@ -0,0 +1,78 @@ +import re +from collections import Counter +from typing import Union + +import numpy as np + + +def yellow_string(str: str) -> str: + return f"\033[1;33m{str}\033[0m" + + +def median(numbers: list[int]) -> int: + """ + Returns the median of the given list of numbers. If the list has an even + number of elements, the arithmetic mean of the two middle elements of the + sorted list is returned. + """ + return np.median(numbers) + + +def mode(numbers: list[int]) -> int: + """ + Returns the mode of the given list of numbers. 
If there are multiple modes, + the largest mode is returned. + """ + frequency = {} + for number in numbers: + frequency[number] = frequency.get(number, 0) + 1 + + max_frequency = max(frequency.values()) + candidates = [number for number, freq in frequency.items() if freq == max_frequency] + + return max(candidates) + + +task_to_fn = {"median": median, "mode": mode} + + +def parse_solver_output(solver_output: str, task: str) -> Union[int, None]: + solver_string = solver_output.strip().lower() + pattern = rf"\[{task}: (\d+(?:\.\d+)?)\]" + + match = re.search(pattern, solver_string) + + if match: + try: + output = float(match.group(1)) + except ValueError: + output = None + else: + output = None + + return output + + +def compute_mode_state(curr_list: list[int]) -> dict: + counter = Counter(curr_list) + return dict(counter) + + +def compute_median_state(curr_list: list[int]) -> dict: + sorted_list = sorted(curr_list) + return sorted_list + + +def compute_state(curr_list: list[int], task) -> dict: + if task == "mode": + return { + "task_name": task, + "state_label": "number to count", + "state_data": compute_mode_state(curr_list), + } + else: + return { + "task_name": task, + "state_label": "sorted list of shown numbers", + "state_data": compute_median_state(curr_list), + } diff --git a/evals/elsuite/twenty_questions/eval.py b/evals/elsuite/twenty_questions/eval.py new file mode 100644 index 0000000000..3cb0d5c857 --- /dev/null +++ b/evals/elsuite/twenty_questions/eval.py @@ -0,0 +1,204 @@ +import logging +import random +import re +from typing import Any, Dict, List, Optional, Union + +import evals +import evals.metrics +from evals.api import CompletionFn +from evals.elsuite.twenty_questions.utils import PROMPTS, generate_task_state_for +from evals.eval import SolverEval +from evals.record import Recorder +from evals.registry import registry +from evals.solvers.human_cli_solver import HumanCliSolver +from evals.solvers.solver import Solver +from evals.solvers.utils 
import maybe_wrap_with_solver +from evals.task_state import Message + +logger = logging.getLogger(__name__) +WORD_PATTERN = r"\[GUESS (.*?)\]" + + +class TwentyQuestions(SolverEval): + def __init__( + self, + completion_fns: List[CompletionFn], + samples_jsonl: str, + gamemaster_spec: str, + max_questions: int = 20, + max_replies: int = 40, + num_shortlist_items: int = 20, + shortlist_variant: bool = False, + seed: int = 222024, + n_samples: Optional[int] = None, + *args, + **kwargs, + ): + super().__init__(completion_fns, seed=seed, *args, **kwargs) + + self.samples_jsonl = samples_jsonl + self.gamemaster_solver = maybe_wrap_with_solver( + registry.make_completion_fn(gamemaster_spec) + ) + self.max_questions = max_questions + + if max_replies < max_questions: + logger.warning( + f"max_replies ({max_replies}) is less than max_questions ({max_questions}). Setting max_replies to {max_questions + 20}" + ) + self.max_replies = max_replies if max_replies > max_questions else max_questions + 20 + self.num_shortlist_items = num_shortlist_items + self.shortlist_variant = shortlist_variant + + self.n_samples = n_samples + self.rng = random.Random(seed) + + def eval_sample(self, solver: Solver, sample: Dict, rng: random.Random) -> Dict[str, Any]: + assert "word" in sample, "Sample must contain 'word' field" + assert "difficulty" in sample, "Sample must contain 'difficulty' field" + + if not isinstance(solver, HumanCliSolver): + logger.info(f"Running sample: {sample['word']}") + + # Generate the shortlist for the current sample if applicable. + if self.shortlist_variant: + assert self.num_shortlist_items <= len( + self.shortlist + ), "Number of shortlist items must be less than or equal to the total number of samples."
+ shortlist_for_sample = rng.sample(self.shortlist, self.num_shortlist_items) + if sample["word"] not in shortlist_for_sample: + random_index = rng.randint(0, len(shortlist_for_sample) - 1) + shortlist_for_sample[random_index] = sample["word"] + else: + shortlist_for_sample = None + response = self._conversation_loop(solver, sample, shortlist_for_sample) + + return response + + def run(self, recorder: Recorder) -> Dict[str, Union[float, int]]: + samples = self.get_samples() + self.rng.shuffle(samples) + samples = samples[: self.n_samples] if self.n_samples else samples + + if self.shortlist_variant: + self.shortlist = [sample["word"] for sample in samples] + + self.eval_all_samples(recorder, samples) + events = recorder.get_events("match") + + scores = [event.data["score"] for event in events] + num_guesses = [event.data["num_guesses"] for event in events] + num_questions = [event.data["num_questions"] for event in events] + num_violations = [event.data["num_violations"] for event in events] + num_gamemaster_refusals = [event.data["num_gamemaster_refusals"] for event in events] + incorrect_guesses = [event.data["incorrect_guesses"] for event in events] + word_difficulties = [event.data["word_difficulty"] for event in events] + + return { + "score": sum(scores) / len(scores), + "accuracy": evals.metrics.get_accuracy(events), + "bootstrap_std": evals.metrics.get_bootstrap_accuracy_std(events), + "average_num_guesses": sum(num_guesses) / len(num_guesses), + "average_num_questions": sum(num_questions) / len(num_questions), + "average_num_violations": sum(num_violations) / len(num_violations), + "average_num_gamemaster_refusals": sum(num_gamemaster_refusals) + / len(num_gamemaster_refusals), + "average_num_incorrect_guesses": sum((len(ig) for ig in incorrect_guesses)) + / len(incorrect_guesses), + "average_word_difficulty": sum(word_difficulties) / len(word_difficulties), + } + + def _conversation_loop( + self, solver: Solver, sample: Dict, shortlist: 
Optional[List[str]] = None + ) -> Dict[str, Any]: + """Maintains a conversation between the guesser and the gamemaster until the maximum number of questions is reached, or until a correct guess is made. + + Args: + solver (Solver): any compatible solver, instantiated for the current sample. + sample (Dict): current sample – one word to guess, and its associated difficulty. + + Returns: + Dict[str, Any]: a dictionary containing the final result and metrics of the conversation. + """ + + metrics = { + "num_guesses": 0, + "num_questions": 0, + "num_violations": 0, + "num_guesser_replies": 0, # num_guesses + num_questions + num_violations + "num_gamemaster_refusals": 0, + "incorrect_guesses": [], + } + conversation = [] + + # Contains fall-back condition to avoid infinite loops for solvers which never output questions. + while ( + metrics["num_questions"] < self.max_questions + and metrics["num_guesser_replies"] < self.max_replies + ): + task_state = generate_task_state_for( + "guesser", conversation, max_questions=self.max_questions, shortlist=shortlist + ) + guesser_response = solver(task_state) + conversation += [Message(content=guesser_response.output, role="guesser")] + metrics["num_guesser_replies"] += 1 + + # Check if guess made: + match = re.search(WORD_PATTERN, guesser_response.output) + if match is not None: + metrics["num_guesses"] += 1 + guess = match.group(1) + if guess.lower() == sample["word"].lower(): + response = { + "correct": True, + "score": self.max_questions - metrics["num_questions"], + "expected": sample["word"], + "word_difficulty": sample["difficulty"], + "picked": guess, + "num_guesses": metrics["num_guesses"], + "num_questions": metrics["num_questions"], + "num_violations": metrics["num_violations"], + "num_gamemaster_refusals": metrics["num_gamemaster_refusals"], + "incorrect_guesses": metrics["incorrect_guesses"], + } + evals.record.record_match(**response) + return response + else: + metrics["incorrect_guesses"] += [guess] + conversation 
+= [ + Message( + content=PROMPTS["incorrect_guess"].format(guess=guess), role="system" + ) + ] + continue + elif "?" in guesser_response.output.strip(): + metrics["num_questions"] += 1 + else: # Neither guess nor question. + # TODO: Maybe make the guesser retry here? + logger.warn( + f"Rule violation, no guess or question in output: {guesser_response.output}" + ) + metrics["num_violations"] += 1 + conversation += [Message(content=PROMPTS["rule_violation"], role="system")] + continue + + task_state = generate_task_state_for("gamemaster", conversation, sample["word"]) + gamemaster_response = self.gamemaster_solver(task_state) + conversation += [Message(content=gamemaster_response.output, role="gamemaster")] + if gamemaster_response.output.lower() == "skip": + metrics["num_gamemaster_refusals"] += 1 + + logger.info(f"Ran out of questions for word: {sample['word']}") + response = { + "correct": False, + "score": 0, + "expected": sample["word"], + "word_difficulty": sample["difficulty"], + "num_guesses": metrics["num_guesses"], + "num_questions": metrics["num_questions"], + "num_violations": metrics["num_violations"], + "num_gamemaster_refusals": metrics["num_gamemaster_refusals"], + "incorrect_guesses": metrics["incorrect_guesses"], + } + evals.record.record_match(**response) + return response diff --git a/evals/elsuite/twenty_questions/readme.md b/evals/elsuite/twenty_questions/readme.md new file mode 100644 index 0000000000..89f2e3ce6c --- /dev/null +++ b/evals/elsuite/twenty_questions/readme.md @@ -0,0 +1,82 @@ +# 20 Questions + +This eval tests models' ability to generate and iterate over hypotheses by playing the game of "20 questions". In 20 questions, one of the players – the "gamemaster" – thinks of a word (in our case a noun) and the other player needs to guess it. To help them guess, the player can ask up to 20 yes-or-no questions, which the gamemaster must answer. + +## Usage +Run with: +```bash +# Standard variant. 
+oaieval twenty_questions.full + +# Shortlist variant. +oaieval twenty_questions.shortlist.full +``` + +Where the solver can be any generation solver in `evals/registry/solvers/defaults.yaml`, e.g. `generation/cot/gpt-3.5-turbo-16k`, or the chain-of-thought solvers in `evals/registry/solvers/twenty_questions.yaml`. + +## Evaluation process +We run a dialogue loop between two models for each sample: the evaluated model and the "gamemaster". By default, the gamemaster is gpt-4-turbo-preview – but this can be updated by specifying a different solver in `evals/registry/evals/twenty_questions.yaml`. + +The dialogue continues until the word is guessed correctly, or until 20 questions have been asked, whichever comes first. We also terminate conversations that last longer than 40 replies, to ensure that models which do not ask questions don't have infinite conversations. Both the maximum questions and the maximum replies can be controlled from the eval YAML file. + +## Task State +The task state can be found in `twenty_questions/utils.py`; it reads: +``` +You are about to play the game '20 questions'. The other player has thought of a noun which you should try to guess. You can ask 20 yes/no questions, to which they will answer 'yes', 'no', or 'skip' (if they cannot answer your question). You should try to guess the word as soon as possible, using the least amount of questions. To guess a word, type [GUESS ] – for example to guess the word 'sibling', output [GUESS sibling]. Your score will be 0 if you do not guess correctly, and {max_questions} minus the number of questions you asked if you guess correctly. Start by asking your first question. +``` + +## Prompts +See `twenty_questions/utils.py` to review/adjust the prompts used in this eval. + +## Datasets + +We use a dataset of 207 words, 177 of which were from [this lexicon](https://github.com/mounicam/lexical_simplification), annotated by our team with a difficulty category. 
This dataset comprises: +- 47 words rated “easy”, e.g. ear, father, potato; +- 91 words rated “medium”, e.g. cloth, hike, discount; +- 69 words rated “hard”, e.g. prosperity, gland, philosopher; + +In addition to these common nouns, we include 30 proper nouns such as “Sherlock Holmes,” “The Beatles,” “Titanic,” and “Starbucks”, which span the easy and medium difficulties. + +## Metrics +We measure the score each model achieves, defined as `score = max_questions - questions_asked` when the model guesses the word correctly, and 0 otherwise. We also track the win-rate, i.e. the % of samples the model guesses correctly. Auxiliary metrics such as average number of questions asked, average number of incorrect guesses, and average number of gamemaster refusals (i.e. situations where the gamemaster says 'skip') are also tracked. + + +## Variants + +We run two main variants of this evaluation: +- **standard**: the main variant +- **shortlist**: an easier variant where the evaluated model sees a shortlist of words in its system prompt. The word the gamemaster has selected is part of the list. In this variant, the evaluated model effectively has to narrow down the pool of candidate words until it finds the answer. 
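As a rough illustration of how a shortlist for one sample might be assembled, here is a minimal sketch that mirrors the logic in `eval_sample`: sample a fixed number of words, then overwrite a random slot with the target word if it was not drawn. The helper name `build_shortlist` and the example word list are hypothetical and not part of the eval's code.

```python
import random
from typing import List


def build_shortlist(all_words: List[str], target: str, k: int, rng: random.Random) -> List[str]:
    """Sample k candidate words and make sure the target word is among them.

    Hypothetical helper for illustration only; the eval performs the
    equivalent steps inline.
    """
    if not 1 <= k <= len(all_words):
        raise ValueError("k must be between 1 and the number of available words")
    shortlist = rng.sample(all_words, k)
    if target not in shortlist:
        # Overwrite a random slot so the shortlist keeps exactly k entries.
        shortlist[rng.randrange(k)] = target
    return shortlist


# Example words taken from the dataset description above.
words = ["ear", "father", "potato", "cloth", "hike", "discount", "gland"]
shortlist = build_shortlist(words, target="gland", k=4, rng=random.Random(0))
print(shortlist)
```

Overwriting a slot rather than appending keeps the shortlist length constant, so every sample sees the same number of candidates.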
+ +## Token Usage Estimates + +Below is a rough estimate of the total number of tokens consumed by some variations of the eval, including both input and output tokens: + +Variant | Solver | Model | Prompt tokens | Completion tokens | Total tokens +| --- | --- | --- | --- | --- | --- | +standard | direct | gpt-4-turbo-preview | 2,502,067 | 52,879 | 2,554,946 +standard | direct | gpt-4-base | 13,197,212 | 2,814,623 | 16,011,835 +standard | direct | gpt-3.5-turbo | 2,670,866 | 57,917 | 2,728,783 +standard | cot | gpt-4-turbo-preview | 73,765,861 | 1,881,455 | 75,647,316 +standard | cot | gpt-4-base | 51,777,817 | 6,397,472 | 58,175,289 +standard | cot | gpt-3.5-turbo | 38,236,500 | 199,831 | 38,436,331 +standard | cot | llama-2-70b | 6,785,634 | 581,421 | 7,367,055 +standard | cot | mixtral-8x7b-instruct | 175,956,903 | 5,327,393 | 181,284,296 +shortlist | direct | gpt-4-turbo-preview | 1,237,172 | 28,351 | 1,265,523 +shortlist | direct | gpt-4-base | 11,034,903 | 2,133,487 | 13,168,390 +shortlist | direct | gpt-3.5-turbo | 1,704,154 | 36,356 | 1,740,510 +shortlist | cot | gpt-4-turbo-preview | 10,951,215 | 545,945 | 11,497,160 +shortlist | cot | gpt-4-base | 45,591,363 | 596,429 | 46,187,792 +shortlist | cot | gpt-3.5-turbo | 19,798,263 | 165,731 | 19,963,994 +shortlist | cot | llama-2-70b | 5,980,667 | 528,879 | 6,509,546 +shortlist | cot | mixtral-8x7b-instruct | 143,646,924 | 4,315,806 | 147,962,730 + + +## Version History +v0: Initial version released + + +## Contribution statement + +Eval design, implementation, and results evaluation were primarily conducted by Andrei Alexandru with contributions from Dane Sherburn, under the guidance of (alphabetically by last-name) Steven Adler, James Aung, and Chan Jun Shern who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation. 
+ + diff --git a/evals/elsuite/twenty_questions/scripts/make_plots.py b/evals/elsuite/twenty_questions/scripts/make_plots.py new file mode 100644 index 0000000000..f07b76da5a --- /dev/null +++ b/evals/elsuite/twenty_questions/scripts/make_plots.py @@ -0,0 +1,142 @@ +import argparse +from pathlib import Path +import os + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +from evals.utils import log_utils + +PLOT_TITLES_BY_METRIC = { + "score": "Score", + "winrate": "Win-rate", + "bootstrap_std": "Bootstrapped standard deviation", + "average_num_guesses": "Average guesses per sample", + "average_num_questions": "Average questions per sample", + "average_num_violations": "Average rule violations per sample", + "average_num_gamemaster_refusals": "Average gamemaster refusals per sample", + "average_num_incorrect_guesses": "Average incorrect guesses per sample", + "average_word_difficulty": "Average word difficulty", +} + +HUMAN_BASELINE = { + "standard": { + "winrate": 0.0333, + "score": 0.1333, + "average_num_guesses": 0.3666, + "average_num_questions": 19.8666, + "average_num_violations": 0.62, + "average_num_gamemaster_refusals": 0.28, + "average_num_incorrect_guesses": 0.3333, + "average_word_difficulty": 2.2333, + }, + "shortlist": { + "winrate": 1, + "score": 14.1388, + "average_num_guesses": 1.8611, + "average_num_questions": 5.8611, + "average_num_violations": 0.1944, + "average_num_gamemaster_refusals": 0.1111, + "average_num_incorrect_guesses": 0.8611, + "average_word_difficulty": 2.2777, + } +} + +UNIT_METRICS = ["winrate"] + +def extract_metrics(datadir: Path) -> pd.DataFrame: + df_rows = [] + # There are two eval variants: standard and shortlist. 
+ for variant in os.listdir(datadir): + for path, results in sorted(list(log_utils.get_final_results_from_dir(f"{datadir}/{variant}").items())): + spec = log_utils.extract_spec(path) + solver_path = Path(spec["completion_fns"][0]) + model = solver_path.name + solver = solver_path.parent.name + # Remove root section of path, which is the eval name + solver_path = solver_path.relative_to(solver_path.parts[0]) + df_rows.append({"solver": solver, "model": model, "variant": variant, **results}) + df = pd.DataFrame(df_rows) + df.rename(columns={"accuracy": "winrate"}, inplace=True) + df.sort_values(by=["variant", "model", "solver"], inplace=True) + df.to_csv(datadir / "results.csv", index=False) + + return df + +def make_plot(df: pd.DataFrame, outpath: Path, metric="score", variant="standard"): + df = df.round(2) + plt.figure() + sns.set_theme(style="whitegrid") + + def compute_sem(x): + sem = x.std() / (len(x) ** 0.5) + sem2 = sem * 2 # 95% confidence interval + lower = max(0, (x.mean() - sem2).round(2)) + upper = (x.mean() + sem2).round(2) + return lower, upper + + + # Plotting + sns.set(style="whitegrid") + ax = sns.barplot(x=metric, y="model", hue="solver", data=df, errorbar=compute_sem, capsize=0.1) + for container in ax.containers: + ax.bar_label(container, fmt="{:.2f}", label_type="edge", padding=15) + + ax.axvline(HUMAN_BASELINE[variant][metric], color="red", linestyle="--") + + # A bunch of tweaks to make individual plots look nice. + if variant == "shortlist" and metric == "winrate": + plt.text(HUMAN_BASELINE[variant][metric] - 0.35, .5, "Human baseline", color="red", fontsize=12, ha="left") + elif variant == "standard" and metric == "average_num_questions": + plt.text(HUMAN_BASELINE[variant][metric] - 7, .5, "Human baseline", color="red", fontsize=12, ha="left") + else: + plt.text(HUMAN_BASELINE[variant][metric] + 0.05, .5, "Human baseline", color="red", fontsize=12, ha="left") + + # Some of the metrics are in [0, 1]. 
+ if metric in UNIT_METRICS: + plt.xlim(0, 1.1) + + if metric in ("score", "average_num_questions"): + plt.xlim(0, 20.1) + + if metric == "average_word_difficulty": + plt.xlim(0, 3.1) # 6 is the maximum word difficulty in the dataset. + + if metric in ("score", "winrate"): + plt.legend(loc="lower right") + + plt.title(PLOT_TITLES_BY_METRIC[metric] + f" ({variant} variant)") + plt.xlabel(metric) + plt.tight_layout() + plt.savefig(outpath) + plt.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--log-dir", "-d", type=str, required=True) + parser.add_argument("--out-dir", "-o", type=str, default="./outputs") + args = parser.parse_args() + log_dir = Path(args.log_dir) + out_dir = Path(args.out_dir) + + out_dir.mkdir(exist_ok=True, parents=True) + + df = extract_metrics(log_dir) + + # Rename some of the solver values so they can be represented in the same plot. + df.loc[df['solver'] == 'cot_hhh', 'solver'] = 'cot' + df.loc[df['solver'] == 'hhh', 'solver'] = 'direct' + + for variant in df['variant'].unique(): + df_per_variant = df[df['variant'] == variant] + + print(f"Plotting all metrics for {variant} variant...") + + core_metrics = ["score", "winrate"] + auxiliary_metrics = ["average_num_guesses", "average_num_questions", "average_num_violations", "average_num_gamemaster_refusals", "average_num_incorrect_guesses", "average_word_difficulty"] + for metric in core_metrics + auxiliary_metrics: + make_plot(df_per_variant[["model", "solver", metric]].copy(), + out_dir / f"{variant}_{metric}.png", + metric, + variant) \ No newline at end of file diff --git a/evals/elsuite/twenty_questions/scripts/run_experiments.sh b/evals/elsuite/twenty_questions/scripts/run_experiments.sh new file mode 100644 index 0000000000..4b8718d607 --- /dev/null +++ b/evals/elsuite/twenty_questions/scripts/run_experiments.sh @@ -0,0 +1,60 @@ +logdir=./logs +outputdir=./outputs + +timestamp=$(date +%Y%m%d_%H%M%S) +logpathbase=$logdir/$timestamp + 
+num_repeats=1 + +# Check for --num_repeats argument +for arg in "$@" +do + if [[ $arg == --num_repeats=* ]]; then + num_repeats="${arg#*=}" + fi +done + +echo Num repeats is: $num_repeats +echo Running experiments and logging to $logpathbase + +declare -a SOLVERS=( + # Solvers for gpt-3.5-turbo + "generation/direct/gpt-3.5-turbo" + "twenty_questions/cot/gpt-3.5-turbo" + + # # Solvers for gpt-4-turbo-preview + "generation/direct/gpt-4-turbo-preview" + "twenty_questions/cot/gpt-4-turbo-preview" + + # # Solvers for gpt-4-base + "generation/hhh/gpt-4-base" + "twenty_questions/cot_hhh/gpt-4-base" +) + +if [ ! -d "$logpathbase/standard" ]; then + mkdir -p "$logpathbase/standard" +fi + +if [ ! -d "$logpathbase/shortlist" ]; then + mkdir -p "$logpathbase/shortlist" +fi + + for solver in "${SOLVERS[@]}" + do + for ((i=1;i<=num_repeats;i++)) + do + echo "Running $solver, iteration $i, standard variant." + oaieval $solver twenty_questions.full --record_path "$logpathbase/standard/$solver-$i.log" + + echo "Running $solver, iteration $i, shortlist variant." 
+ oaieval $solver twenty_questions.shortlist.full --record_path "$logpathbase/shortlist/$solver-$i.log" + done + done + +echo Done running experiments, all logs in $logpathbase + +echo Producing plots, outputs to $outputdir + +# Produce results +python scripts/make_plots.py --log-dir $logpathbase --out-dir $outputdir \ No newline at end of file diff --git a/evals/elsuite/twenty_questions/test_utils.py b/evals/elsuite/twenty_questions/test_utils.py new file mode 100644 index 0000000000..069b86c8ad --- /dev/null +++ b/evals/elsuite/twenty_questions/test_utils.py @@ -0,0 +1,27 @@ +from utils import format_msg, format_msgs +from evals.task_state import Message + +def test_format_msg(): + msg = Message(content="I'm a message", role="guesser") + + assert format_msg(msg, "guesser") == Message(content="I'm a message", role="assistant") + assert format_msg(msg, "gamemaster") == Message(content="I'm a message", role="user") + +def test_format_msgs(): + msgs = [ + Message(content="I'm a guesser message", role="guesser"), + Message(content="I'm a gamemaster message", role="gamemaster"), + Message(content="I'm another guesser message", role="guesser"), + ] + + assert format_msgs(msgs, "guesser") == [ + Message(content="I'm a guesser message", role="assistant"), + Message(content="I'm a gamemaster message", role="user"), + Message(content="I'm another guesser message", role="assistant"), + ] + + assert format_msgs(msgs, "gamemaster") == [ + Message(content="I'm a guesser message", role="user"), + Message(content="I'm a gamemaster message", role="assistant"), + Message(content="I'm another guesser message", role="user"), + ] \ No newline at end of file diff --git a/evals/elsuite/twenty_questions/utils.py b/evals/elsuite/twenty_questions/utils.py new file mode 100644 index 0000000000..169dd0a977 --- /dev/null +++ b/evals/elsuite/twenty_questions/utils.py @@ -0,0 +1,69 @@ +from typing import Literal, List, Optional +from evals.task_state import Message, TaskState + +PROMPTS = 
{ + "guesser": "You are about to play the game '20 questions'. The other player has thought of a noun which you should try to guess. You can ask 20 yes/no questions, to which they will answer 'yes', 'no', or 'skip' (if they cannot answer your question). You should try to guess the word as soon as possible, using the least amount of questions. To guess a word, type [GUESS ] – for example to guess the word 'sibling', output [GUESS sibling]. Your score will be 0 if you do not guess correctly, and {max_questions} minus the number of questions you asked if you guess correctly. Start by asking your first question.", + "guesser_shortlist": "You are about to play the game '20 questions'. The other player has thought of a noun which you should try to guess. The noun is in this list: {shortlist}.\nYou can ask 20 yes/no questions, to which they will answer 'yes', 'no', or 'skip' (if they cannot answer your question). You should try to guess the word as soon as possible, using the least amount of questions. To guess a word, type [GUESS ] – for example to guess the word 'sibling', output [GUESS sibling]. Your score will be 0 if you do not guess correctly, and {max_questions} minus the number of questions you asked if you guess correctly. Start by asking your first question.", + "gamemaster": "You are about to play the game '20 questions'. You'll be given a noun, which the other player will try to guess. They will ask 20 yes/no questions to try to guess the word. You will answer their questions with 'yes', 'no', or 'skip'. The word is: {word}.", + "incorrect_guess": "Guess {guess} is incorrect. Ask more questions, or make another guess!", + "rule_violation": "Your output was neither a guess nor a question. Try again! You can ask a yes/no question, or make a guess by outputting [GUESS ]." 
+} + +def generate_task_state_for(role: Literal["guesser", "gamemaster"], conversation: list[Message], word: Optional[str] = None, max_questions: int = 20, shortlist: Optional[List[str]] = None) -> TaskState: + """Generates a TaskState for the given role and conversation.""" + if role == "guesser": + prompt = PROMPTS["guesser"].format(max_questions=max_questions) if shortlist is None else PROMPTS["guesser_shortlist"].format(max_questions=max_questions, shortlist=shortlist) + elif role == "gamemaster": + prompt = PROMPTS[role].format(word=word) + else: + raise ValueError(f"Invalid role: {role}") + + formatted_conversation = format_msgs(conversation, role) + + return TaskState( + task_description=prompt, + messages=formatted_conversation, + ) + + +def format_msgs( + messages: list[Message], + role: Literal["guesser", "gamemaster"], +) -> list[Message]: + """Format messages from the perspective of the `role`.""" + new_messages = [format_msg(msg, role) for msg in messages] + + # post-conditions + for m in new_messages: + assert m.role in ["user", "assistant", "system"] + + return new_messages + +def format_msg(msg: Message, role: Literal["guesser", "gamemaster"]) -> Message: + """Formats a single message from the perspective of the `role`.""" + + # body + is_others_msg = role not in msg.role + new_content = msg.content + + if is_others_msg: + new_role = "user" + elif is_system_msg(msg): + new_role = "system" + else: + new_role = "assistant" + + new_message = Message(content=new_content, role=new_role) + + # post-conditions + assert isinstance(new_message.content, str) + assert new_message.role in ["user", "assistant", "system"] + + return new_message + +def is_system_msg(m: Message) -> bool: + assert isinstance(m, Message), "Message must be a Message type." + assert hasattr(m, "role"), "Message must have a role." + assert isinstance(m.role, str), "Message role must be a string." 
+ + return m.role == "system" \ No newline at end of file diff --git a/evals/registry/data/already_said_that/500_100.jsonl b/evals/registry/data/already_said_that/500_100.jsonl new file mode 100644 index 0000000000..1804855f51 --- /dev/null +++ b/evals/registry/data/already_said_that/500_100.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce6eeadef5aa33a57abcd0b42bbdb02d1a4eb3c61a35bcc1da9b1b29766f3469 +size 633725 diff --git a/evals/registry/data/already_said_that/LICENSE b/evals/registry/data/already_said_that/LICENSE new file mode 100644 index 0000000000..5a6c138028 --- /dev/null +++ b/evals/registry/data/already_said_that/LICENSE @@ -0,0 +1,3 @@ +WordNet: +WordNet License: https://wordnet.princeton.edu/license-and-commercial-use +Source: https://wordnet.princeton.edu/ diff --git a/evals/registry/data/bugged_tools/main.jsonl b/evals/registry/data/bugged_tools/main.jsonl new file mode 100644 index 0000000000..64bc2aa725 --- /dev/null +++ b/evals/registry/data/bugged_tools/main.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:418a355616ca1749a425d2ebce92a6441d62c22858d391887e1aae554db9b724 +size 257364 diff --git a/evals/registry/data/bugged_tools/main_small.jsonl b/evals/registry/data/bugged_tools/main_small.jsonl new file mode 100644 index 0000000000..c2cd4596ca --- /dev/null +++ b/evals/registry/data/bugged_tools/main_small.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3ef5a5fadf027b15af803dca5d4c61e9abfe27c29f984a76895467de79b30a1 +size 25661 diff --git a/evals/registry/data/cant_do_that_anymore/diagonal_moves_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/diagonal_moves_dataset.jsonl new file mode 100644 index 0000000000..7cce7ab588 --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/diagonal_moves_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:345340a9c74ae6d3ad73393b43986c37fa30ad2df8e94d147d9f63cf519e703e 
+size 540964 diff --git a/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-0125_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-0125_dataset.jsonl new file mode 100644 index 0000000000..d63a762d37 --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-0125_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08d0cbf162d7b46e8931c74816f597085d5d365895e7f8c9f9b20d98be0566c8 +size 170427 diff --git a/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-instruct_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-instruct_dataset.jsonl new file mode 100644 index 0000000000..43161bec40 --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-instruct_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3d9927244f61a7e00d7b4d9e5521b8ad3249be08cbf8afd3c75b30fe8f4e9a5 +size 223466 diff --git a/evals/registry/data/cant_do_that_anymore/gpt-4-0125-preview_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/gpt-4-0125-preview_dataset.jsonl new file mode 100644 index 0000000000..1c693f76de --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/gpt-4-0125-preview_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80a41ce88bab1d6b9315835fa2845bb754ed52d0d7983857f255f5de0fd2fbdb +size 283930 diff --git a/evals/registry/data/cant_do_that_anymore/gpt-4-0314_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/gpt-4-0314_dataset.jsonl new file mode 100644 index 0000000000..e6dffa7d4d --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/gpt-4-0314_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5df2376c0805ea323dddec11a01d5d843edce069f86550f2a9e91efcad4f51cc +size 549365 diff --git a/evals/registry/data/cant_do_that_anymore/special_moves_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/special_moves_dataset.jsonl new file mode 
100644 index 0000000000..6f5e89e691 --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/special_moves_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:baea567fbd18be57a6fba31a8e7d05a670bfd86799397269aa9b47ab6d2f2a5b +size 3381675 diff --git a/evals/registry/data/error_recovery/main.jsonl b/evals/registry/data/error_recovery/main.jsonl new file mode 100644 index 0000000000..77835457c7 --- /dev/null +++ b/evals/registry/data/error_recovery/main.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fda8fddd6a63d6b84ee4b6a8934bcedcada67e3fcd5df64041f14c04d774be3 +size 1543818 diff --git a/evals/registry/data/error_recovery/medium.jsonl b/evals/registry/data/error_recovery/medium.jsonl new file mode 100644 index 0000000000..77b989dee3 --- /dev/null +++ b/evals/registry/data/error_recovery/medium.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5c591504d282ca7763d7abe407958da1ea06d6dc62be4808ba4fa97ff5f3cb2 +size 280075 diff --git a/evals/registry/data/error_recovery/small.jsonl b/evals/registry/data/error_recovery/small.jsonl new file mode 100644 index 0000000000..64172d3d10 --- /dev/null +++ b/evals/registry/data/error_recovery/small.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e55b1af640b26eff5661c83c7ff6bf52040ea062c9a71ba16069e2305fdb362 +size 10191 diff --git a/evals/registry/data/function_deduction/data.jsonl b/evals/registry/data/function_deduction/data.jsonl new file mode 100644 index 0000000000..bded32c52b --- /dev/null +++ b/evals/registry/data/function_deduction/data.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb7cd13c1f67a7be8d153de26c7436a805035053f5497b77296e3f3615023e86 +size 50468 diff --git a/evals/registry/data/identifying_variables/balanced_ctrl_vars.jsonl b/evals/registry/data/identifying_variables/balanced_ctrl_vars.jsonl new file mode 100644 index 0000000000..c29a8ee65d --- 
/dev/null +++ b/evals/registry/data/identifying_variables/balanced_ctrl_vars.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9429fe712578ae4298e012cc374198bf83cf968115004dc00d24e42ebdc4f1d +size 12525123 diff --git a/evals/registry/data/identifying_variables/balanced_hypotheses.jsonl b/evals/registry/data/identifying_variables/balanced_hypotheses.jsonl new file mode 100644 index 0000000000..cb05f29c53 --- /dev/null +++ b/evals/registry/data/identifying_variables/balanced_hypotheses.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e92ee79ee832d7f6f40e55cad82fe26100ea3c1ca1faac2f606a046ef4a09b79 +size 7554989 diff --git a/evals/registry/data/incontext_rl/samples.jsonl b/evals/registry/data/incontext_rl/samples.jsonl new file mode 100644 index 0000000000..acacbee595 --- /dev/null +++ b/evals/registry/data/incontext_rl/samples.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a675930c0b31dcee9dca9f653085f9eb2b856c1284c289ed5501d44bd94fec5 +size 4138 diff --git a/evals/registry/data/incontext_rl/samples_dev.jsonl b/evals/registry/data/incontext_rl/samples_dev.jsonl new file mode 100644 index 0000000000..110af6c8cf --- /dev/null +++ b/evals/registry/data/incontext_rl/samples_dev.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:863664b313c3c8e77e3da6ad6f0ef695e8f86ff9d1ecdd7d5fcf0d408bf464da +size 1617 diff --git a/evals/registry/data/incontext_rl/samples_gymnasium_only.jsonl b/evals/registry/data/incontext_rl/samples_gymnasium_only.jsonl new file mode 100644 index 0000000000..d0448241d0 --- /dev/null +++ b/evals/registry/data/incontext_rl/samples_gymnasium_only.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7314053ae7203d627611fadb2d5f04f2aa6b001def00047bca206d0db43cb62b +size 3455 diff --git a/evals/registry/data/skill_acquisition/miskito/knowledge_base/honduras.jsonl 
b/evals/registry/data/skill_acquisition/miskito/knowledge_base/honduras.jsonl new file mode 100644 index 0000000000..cc818363b7 --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/knowledge_base/honduras.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50b995b959aa7308a0be6413d005c0407984cf6f57a953c1fdde745f17df0db4 +size 72360 diff --git a/evals/registry/data/skill_acquisition/miskito/knowledge_base/human_rights_miskito.jsonl b/evals/registry/data/skill_acquisition/miskito/knowledge_base/human_rights_miskito.jsonl new file mode 100644 index 0000000000..fe48f48eef --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/knowledge_base/human_rights_miskito.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3baae4eade2acc21395c8b29a1f82cc05da00b7f7bc4cd458cc8ee2f7d032cb +size 10298 diff --git a/evals/registry/data/skill_acquisition/miskito/knowledge_base/miskito_language.jsonl b/evals/registry/data/skill_acquisition/miskito/knowledge_base/miskito_language.jsonl new file mode 100644 index 0000000000..7b118988e6 --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/knowledge_base/miskito_language.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2972b14f1f6aa0fb4246a3d4a964cf07c0dfc3e717b6036ccff7d1f6284e7812 +size 7399 diff --git a/evals/registry/data/skill_acquisition/miskito/knowledge_base/miskito_lessons.jsonl b/evals/registry/data/skill_acquisition/miskito/knowledge_base/miskito_lessons.jsonl new file mode 100644 index 0000000000..fd42093260 --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/knowledge_base/miskito_lessons.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f657754efc73614292b53c313583cd0013a9f7bde1e6018220d0bd15a546838c +size 43506 diff --git a/evals/registry/data/skill_acquisition/miskito/knowledge_base/miskito_people.jsonl 
b/evals/registry/data/skill_acquisition/miskito/knowledge_base/miskito_people.jsonl new file mode 100644 index 0000000000..eb18d39508 --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/knowledge_base/miskito_people.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2be3a3684c1586cc0779ae4cf47866d0e88bd8f67c5256438fe59aaa2e8a81b7 +size 53928 diff --git a/evals/registry/data/skill_acquisition/miskito/knowledge_base/mosquito.jsonl b/evals/registry/data/skill_acquisition/miskito/knowledge_base/mosquito.jsonl new file mode 100644 index 0000000000..f3abc68f3e --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/knowledge_base/mosquito.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a69fde31a05e3f95e34bcbbc7e9986e3bf107513658a6e002ae8bb303d69d7d8 +size 28786 diff --git a/evals/registry/data/skill_acquisition/miskito/knowledge_base/mosquito_coast.jsonl b/evals/registry/data/skill_acquisition/miskito/knowledge_base/mosquito_coast.jsonl new file mode 100644 index 0000000000..29e7de5a3f --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/knowledge_base/mosquito_coast.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1539eceeb715376db2db9eb17dcc7c5e43d0e71df710c65b71a7d6276c23dc44 +size 34533 diff --git a/evals/registry/data/skill_acquisition/miskito/knowledge_base/nicaragua.jsonl b/evals/registry/data/skill_acquisition/miskito/knowledge_base/nicaragua.jsonl new file mode 100644 index 0000000000..c71d1603ef --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/knowledge_base/nicaragua.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b446d17a582e1d8bdf2c6c46a742716c7290dc441559817c841361c3e33c39fd +size 80204 diff --git a/evals/registry/data/skill_acquisition/miskito/qa_pairs_by_lesson.jsonl b/evals/registry/data/skill_acquisition/miskito/qa_pairs_by_lesson.jsonl new file mode 100644 index 0000000000..226af11348 
--- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/qa_pairs_by_lesson.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:92c631af79044257aea396b250a93eb466d404d637c6c0fc764a30763576f5ea +size 32651 diff --git a/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_all.jsonl b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_all.jsonl new file mode 100644 index 0000000000..d72e8b83fa --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_all.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c9540f646ea2610874b3e33286e300cd92b70d91f6c00f5b0275f1be918b74a +size 38464 diff --git a/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_all_fewshot.jsonl b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_all_fewshot.jsonl new file mode 100644 index 0000000000..8114c4e111 --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_all_fewshot.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2af150986f257a3e358d76b5a878d17116b583332eee303f0792fcffd1eee6d1 +size 37930 diff --git a/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_manipulation.jsonl b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_manipulation.jsonl new file mode 100644 index 0000000000..151136d565 --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_manipulation.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51ec99e36e05dd2ee0f87f9177c4c4fc0155c744b1ed30d26bf4464ff7985e4f +size 28627 diff --git a/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_manipulation_fewshot.jsonl b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_manipulation_fewshot.jsonl new file mode 100644 index 0000000000..6a5e655d82 --- /dev/null +++ 
b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_manipulation_fewshot.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62b665285283e232bbd670c32900078c77380b5c3c612d3fa11b4369e007edd5 +size 28201 diff --git a/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_translation.jsonl b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_translation.jsonl new file mode 100644 index 0000000000..491624d22c --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_translation.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19d8e9bbe4868479d2df0f6c7e72740399db5943dde1d3109c66affe878a62d8 +size 9836 diff --git a/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_translation_fewshot.jsonl b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_translation_fewshot.jsonl new file mode 100644 index 0000000000..7e91554f6f --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/variants/miskito_test_translation_fewshot.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a76625921e1810e4ce22ba76f4881ee9327e1522555cc9eccc6beb854b7a129 +size 9236 diff --git a/evals/registry/data/skill_acquisition/miskito/variants/miskito_train_all.jsonl b/evals/registry/data/skill_acquisition/miskito/variants/miskito_train_all.jsonl new file mode 100644 index 0000000000..17d04c41d2 --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/variants/miskito_train_all.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f099000b117f1d5d46778263674998e9785f3e993b4c32ca8afb5f82065e1afb +size 560 diff --git a/evals/registry/data/skill_acquisition/miskito/variants/miskito_train_manipulation.jsonl b/evals/registry/data/skill_acquisition/miskito/variants/miskito_train_manipulation.jsonl new file mode 100644 index 0000000000..480e645ad3 --- /dev/null +++ 
b/evals/registry/data/skill_acquisition/miskito/variants/miskito_train_manipulation.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:625f3fa618e52688f6593774b7ba5691879882dbe9e3a8508a8aed43327f7d86 +size 425 diff --git a/evals/registry/data/skill_acquisition/miskito/variants/miskito_train_translation.jsonl b/evals/registry/data/skill_acquisition/miskito/variants/miskito_train_translation.jsonl new file mode 100644 index 0000000000..4f83ef3e2d --- /dev/null +++ b/evals/registry/data/skill_acquisition/miskito/variants/miskito_train_translation.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c2f3a4699303b85d49641ec17ec77bff600c75940078d735406db1539da90c4 +size 599 diff --git a/evals/registry/data/twenty_questions/LICENSE b/evals/registry/data/twenty_questions/LICENSE new file mode 100644 index 0000000000..7b971d365d --- /dev/null +++ b/evals/registry/data/twenty_questions/LICENSE @@ -0,0 +1,3 @@ +lexical_simplification: +MIT License: https://opensource.org/licenses/MIT +Source: https://github.com/mounicam/lexical_simplification \ No newline at end of file diff --git a/evals/registry/data/twenty_questions/dataset.jsonl b/evals/registry/data/twenty_questions/dataset.jsonl new file mode 100644 index 0000000000..ea11e6a68e --- /dev/null +++ b/evals/registry/data/twenty_questions/dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a8358c42ef70c2c48c6bb2e214787e968cd1b092daeb1dd572f942bd7146bff +size 7664 diff --git a/evals/registry/data/twenty_questions/lexicon_nouns.jsonl b/evals/registry/data/twenty_questions/lexicon_nouns.jsonl new file mode 100644 index 0000000000..869a13feb1 --- /dev/null +++ b/evals/registry/data/twenty_questions/lexicon_nouns.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:754d1f85637de87dac8aadfa5163f073d65289f27677031e765334c786742171 +size 112218 diff --git a/evals/registry/evals/already_said_that.yaml 
b/evals/registry/evals/already_said_that.yaml new file mode 100644 index 0000000000..2895544ffe --- /dev/null +++ b/evals/registry/evals/already_said_that.yaml @@ -0,0 +1,50 @@ +already_said_that: + id: already_said_that.reverse-sort-words-eng + metrics: + [ + "avg_num_turns", + "stddev_num_turns", + "median_num_turns", + "max_num_turns", + "min_num_turns", + "false_positive_rate", + "false_negative_rate", + "avg_distractor_accuracy", + "violation_rate", + "avg_num_distractors", + "stddev_num_distractors", + "median_num_distractors", + "max_num_distractors", + "min_num_distractors", + ] + description: "Sustain performance in the presence of distractors" + +already_said_that.which-is-heavier: + class: evals.elsuite.already_said_that.eval:AlreadySaidThat + args: + samples_jsonl: already_said_that/500_100.jsonl + distractor_variant: which-is-heavier + +already_said_that.first-letters: + class: evals.elsuite.already_said_that.eval:AlreadySaidThat + args: + samples_jsonl: already_said_that/500_100.jsonl + distractor_variant: first-letters + +already_said_that.ambiguous-sentences: + class: evals.elsuite.already_said_that.eval:AlreadySaidThat + args: + samples_jsonl: already_said_that/500_100.jsonl + distractor_variant: ambiguous-sentences + +already_said_that.reverse-sort-words-eng: + class: evals.elsuite.already_said_that.eval:AlreadySaidThat + args: + samples_jsonl: already_said_that/500_100.jsonl + distractor_variant: reverse-sort-words-eng + +already_said_that.distractorless: + class: evals.elsuite.already_said_that.eval:AlreadySaidThat + args: + samples_jsonl: already_said_that/500_100.jsonl + distractor_variant: distractorless diff --git a/evals/registry/evals/bugged_tools.yaml b/evals/registry/evals/bugged_tools.yaml new file mode 100644 index 0000000000..ff63e87321 --- /dev/null +++ b/evals/registry/evals/bugged_tools.yaml @@ -0,0 +1,31 @@ +bugged_tools: + id: bugged_tools.all + metrics: [f1, precision, recall, accuracy] + description: Evaluates ability to 
identify bugs in tools + +bugged_tools.all: + class: evals.elsuite.bugged_tools.eval:BuggedTools + args: + samples_jsonl: bugged_tools/main.jsonl + max_turns: 10 + log_all_metrics: False + use_judge: True + bug_instructions_type: simple_warning + +bugged_tools.all_log: + class: evals.elsuite.bugged_tools.eval:BuggedTools + args: + samples_jsonl: bugged_tools/main.jsonl + max_turns: 10 + log_all_metrics: True + use_judge: True + bug_instructions_type: simple_warning + +bugged_tools.all_small: + class: evals.elsuite.bugged_tools.eval:BuggedTools + args: + samples_jsonl: bugged_tools/main_small.jsonl + max_turns: 10 + log_all_metrics: False + use_judge: True + bug_instructions_type: simple_warning diff --git a/evals/registry/evals/cant_do_that_anymore.yaml b/evals/registry/evals/cant_do_that_anymore.yaml new file mode 100644 index 0000000000..d7254a9545 --- /dev/null +++ b/evals/registry/evals/cant_do_that_anymore.yaml @@ -0,0 +1,23 @@ +cant_do_that_anymore: + id: cant_do_that_anymore.all + metrics: [variant_impact_factor, delta, predicted_move_proportion, predicted_move_in_variant_proportion, avg_num_previous_moves, std_num_previous_moves] + description: Evaluates how well models can adapt to new rules of an environment (chess) + +cant_do_that_anymore.all: + class: evals.elsuite.cant_do_that_anymore.eval:CantDoThatAnymore + args: + default_model_dataset: "gpt-3.5-turbo-0125" + n_samples: 1000 + +cant_do_that_anymore.all_small: + class: evals.elsuite.cant_do_that_anymore.eval:CantDoThatAnymore + args: + default_model_dataset: "gpt-3.5-turbo-0125" + n_samples: 100 + +cant_do_that_anymore.all_diagonal: + class: evals.elsuite.cant_do_that_anymore.eval:CantDoThatAnymore + args: + default_model_dataset: "gpt-3.5-turbo-0125" + n_samples: 1000 + diagonal_variation: True diff --git a/evals/registry/evals/error_recovery.yaml b/evals/registry/evals/error_recovery.yaml new file mode 100644 index 0000000000..f42e0e9243 --- /dev/null +++ b/evals/registry/evals/error_recovery.yaml 
@@ -0,0 +1,36 @@ +error-recovery: + id: error-recovery.main + metrics: [accuracy] + description: TODO + +error-recovery.main: + class: evals.elsuite.error_recovery.eval:ErrorRecovery + args: + samples_jsonl: error_recovery/main.jsonl + +error-recovery.small: + class: evals.elsuite.error_recovery.eval:ErrorRecovery + args: + samples_jsonl: error_recovery/small.jsonl + +error-recovery.medium: + class: evals.elsuite.error_recovery.eval:ErrorRecovery + args: + samples_jsonl: error_recovery/medium.jsonl + +# --- mark reasoning as 'user' variant --- +error-recovery.main.other-reasoning: + class: evals.elsuite.error_recovery.eval:ErrorRecovery + args: + samples_jsonl: error_recovery/main.jsonl + mark_as_own_reasoning: False +error-recovery.small.other-reasoning: + class: evals.elsuite.error_recovery.eval:ErrorRecovery + args: + samples_jsonl: error_recovery/small.jsonl + mark_as_own_reasoning: False +error-recovery.medium.other-reasoning: + class: evals.elsuite.error_recovery.eval:ErrorRecovery + args: + samples_jsonl: error_recovery/medium.jsonl + mark_as_own_reasoning: False diff --git a/evals/registry/evals/function-deduction.yaml b/evals/registry/evals/function-deduction.yaml new file mode 100644 index 0000000000..337856cd72 --- /dev/null +++ b/evals/registry/evals/function-deduction.yaml @@ -0,0 +1,37 @@ +function_deduction: + id: function_deduction.easy + metrics: [adjusted_avg_rounds, solved_ratio, solved, samples, avg_success_rounds, avg_sample_rounds_std_adjusted, avg_sample_rounds_std_no_failed, solved_ratio_if_any_solved, avg_ask_rounds, avg_guess_rounds, avg_incorrect_format_rounds, solved_avg_complexity, not_solved_avg_complexity, solved_or_not_mann_whitney_u_p_value, sem_adjusted_avg_rounds, sem_avg_success_rounds, sem_avg_guess_rounds, sem_avg_incorrect_format_rounds] + description: Test a model's ability to deduce unknown functions + +function_deduction.easy: + class: evals.elsuite.function_deduction.eval:FunctionDeductionEval + args: + mode: easy + 
n_rounds: 20 + +function_deduction.easy.long: + class: evals.elsuite.function_deduction.eval:FunctionDeductionEval + args: + mode: easy + n_rounds: 20 + n_repeat: 10 + +function_deduction.easy.dev5: + class: evals.elsuite.function_deduction.eval:FunctionDeductionEval + args: + mode: easy + n_rounds: 20 + n_samples: 5 + +function_deduction.hard: + class: evals.elsuite.function_deduction.eval:FunctionDeductionEval + args: + mode: hard + n_rounds: 20 + +function_deduction.hard.dev5: + class: evals.elsuite.function_deduction.eval:FunctionDeductionEval + args: + mode: hard + n_rounds: 20 + n_samples: 5 diff --git a/evals/registry/evals/identifying_variables.yaml b/evals/registry/evals/identifying_variables.yaml new file mode 100644 index 0000000000..32f4ecbafa --- /dev/null +++ b/evals/registry/evals/identifying_variables.yaml @@ -0,0 +1,136 @@ +identifying_variables: + id: identifying_variables.language-corrset.balanced-ctrl + metrics: + [ + "ctrl_nDCG", + "ctrl_recall", + "ctrl_fallout", + "hyp_valid_acc", + "ind_acc", + "dep_acc", + "violation_rate", + ] + description: + "Evaluate the model's ability to identify the right experimental + variables for testing a given hypothesis."
+ +# Balanced-hypotheses datasets + +identifying_variables.markdown.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: markdown +identifying_variables.markdown.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: markdown + group_metrics: true + +identifying_variables.csv.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: csv +identifying_variables.csv.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: csv + group_metrics: true + +identifying_variables.json.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: json +identifying_variables.json.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: json + group_metrics: true + +identifying_variables.language-tabular.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: language-tabular +identifying_variables.language-tabular.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: language-tabular + group_metrics: true + 
+identifying_variables.language-corrset.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: language-corrset +identifying_variables.language-corrset.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: language-corrset + group_metrics: true + +identifying_variables.corrset.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: corrset +identifying_variables.corrset.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: corrset + group_metrics: true + +# Balanced-control datasets + +identifying_variables.csv.balanced-ctrl: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + n_samples: 500 + renderer: csv +identifying_variables.csv.balanced-ctrl-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + renderer: csv + group_metrics: true + +identifying_variables.language-corrset.balanced-ctrl: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + n_samples: 500 + renderer: language-corrset +identifying_variables.language-corrset.balanced-ctrl-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + renderer: language-corrset + group_metrics: true + 
+identifying_variables.corrset.balanced-ctrl: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + n_samples: 500 + renderer: corrset +identifying_variables.corrset.balanced-ctrl-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + renderer: corrset + group_metrics: true diff --git a/evals/registry/evals/incontext_rl.yaml b/evals/registry/evals/incontext_rl.yaml new file mode 100644 index 0000000000..e66f358569 --- /dev/null +++ b/evals/registry/evals/incontext_rl.yaml @@ -0,0 +1,62 @@ +incontext_rl: + id: incontext_rl.gymnasium.v0 + metrics: [] + +incontext_rl.v0: + class: evals.elsuite.incontext_rl.eval:InContextRl + args: + use_explanations: True + samples_jsonl: incontext_rl/samples.jsonl + +incontext_rl.raw.v0: + class: evals.elsuite.incontext_rl.eval:InContextRl + args: + use_explanations: False + samples_jsonl: incontext_rl/samples.jsonl + +incontext_rl.gymnasium.v0: + class: evals.elsuite.incontext_rl.eval:InContextRl + args: + use_explanations: True + samples_jsonl: incontext_rl/samples_gymnasium_only.jsonl + +incontext_rl.gymnasium.raw.v0: + class: evals.elsuite.incontext_rl.eval:InContextRl + args: + use_explanations: False + samples_jsonl: incontext_rl/samples_gymnasium_only.jsonl + +incontext_rl.short.v0: + class: evals.elsuite.incontext_rl.eval:InContextRl + args: + use_explanations: True + max_steps: 100 + samples_jsonl: incontext_rl/samples.jsonl + +incontext_rl.raw.short.v0: + class: evals.elsuite.incontext_rl.eval:InContextRl + args: + use_explanations: False + max_steps: 100 + samples_jsonl: incontext_rl/samples.jsonl + +incontext_rl.gymnasium.short.v0: + class: evals.elsuite.incontext_rl.eval:InContextRl + args: + use_explanations: True + max_steps: 100 + samples_jsonl: incontext_rl/samples_gymnasium_only.jsonl + +incontext_rl.gymnasium.raw.short.v0: + 
class: evals.elsuite.incontext_rl.eval:InContextRl + args: + use_explanations: False + max_steps: 100 + samples_jsonl: incontext_rl/samples_gymnasium_only.jsonl + +incontext_rl.dev.v0: + class: evals.elsuite.incontext_rl.eval:InContextRl + args: + use_explanations: True + max_steps: 5 + samples_jsonl: incontext_rl/samples.jsonl \ No newline at end of file diff --git a/evals/registry/evals/skill_acquisition.yaml b/evals/registry/evals/skill_acquisition.yaml new file mode 100644 index 0000000000..2b4594efff --- /dev/null +++ b/evals/registry/evals/skill_acquisition.yaml @@ -0,0 +1,107 @@ +# --------------- +# Miskito dataset +# --------------- + +skill_acquisition.miskito: + id: skill_acquisition.miskito.zero_shot.full + metrics: [delta_accuracy] + description: Evaluates whether models can learn the Miskito language through retrieval. + +# Miskito manipulation + translation dataset, zero- and few-shot. +skill_acquisition.miskito.zero_shot.full: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_all.jsonl + target_language: miskito + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 +skill_acquisition.miskito.zero_shot.dev5: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_all.jsonl + target_language: miskito + n_samples: 5 + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 + +skill_acquisition.miskito.few_shot.full: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_all_fewshot.jsonl + target_language: miskito + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 +skill_acquisition.miskito.few_shot.dev5: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: 
skill_acquisition/miskito/variants/miskito_test_all_fewshot.jsonl + target_language: miskito + n_samples: 5 + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 + +# Miskito translation-only, zero- and few-shot. +skill_acquisition.miskito.zero_shot.translation.full: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_translation.jsonl + target_language: miskito + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 +skill_acquisition.miskito.zero_shot.translation.dev5: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_translation.jsonl + target_language: miskito + n_samples: 5 + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 + +skill_acquisition.miskito.few_shot.translation.full: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_translation_fewshot.jsonl + target_language: miskito + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 +skill_acquisition.miskito.few_shot.translation.dev5: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_translation_fewshot.jsonl + target_language: miskito + n_samples: 5 + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 + +# Miskito manipulation-only, zero- and few-shot. 
+skill_acquisition.miskito.zero_shot.manipulation.full: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_manipulation.jsonl + target_language: miskito + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 +skill_acquisition.miskito.zero_shot.manipulation.dev5: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_manipulation.jsonl + target_language: miskito + n_samples: 5 + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 + +skill_acquisition.miskito.few_shot.manipulation.full: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_manipulation_fewshot.jsonl + target_language: miskito + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 +skill_acquisition.miskito.few_shot.manipulation.dev5: + class: evals.elsuite.skill_acquisition.eval:SkillAcquisition + args: + samples_jsonl: skill_acquisition/miskito/variants/miskito_test_manipulation_fewshot.jsonl + target_language: miskito + n_samples: 5 + knowledge_base_directory: skill_acquisition/miskito/knowledge_base/ + max_replies: 30 \ No newline at end of file diff --git a/evals/registry/evals/track_the_stat.yaml b/evals/registry/evals/track_the_stat.yaml new file mode 100644 index 0000000000..c64ce0ed40 --- /dev/null +++ b/evals/registry/evals/track_the_stat.yaml @@ -0,0 +1,22 @@ +track_the_stat: + id: track_the_stat.mode + metrics: + [ + "avg_max_length", + "stddev_max_length", + "median_max_length", + "max_max_length", + "min_max_length", + "violation_rate", + ] + description: "Perform a sequential task by keeping track of state implicitly" + +track_the_stat.mode: + class: evals.elsuite.track_the_stat.eval:TrackTheStat + args: + task: mode + +track_the_stat.median: + 
class: evals.elsuite.track_the_stat.eval:TrackTheStat + args: + task: median diff --git a/evals/registry/evals/twenty_questions.yaml b/evals/registry/evals/twenty_questions.yaml new file mode 100644 index 0000000000..af3491ffcf --- /dev/null +++ b/evals/registry/evals/twenty_questions.yaml @@ -0,0 +1,60 @@ +twenty_questions: + id: twenty_questions.full + description: Tests models on the 20 questions game. + metrics: [score, accuracy, average_num_guesses, average_num_questions, average_num_violations, average_num_gamemaster_refusals, average_num_incorrect_guesses, average_word_difficulty] + +twenty_questions.full: + class: evals.elsuite.twenty_questions.eval:TwentyQuestions + args: + samples_jsonl: twenty_questions/dataset.jsonl + gamemaster_spec: twenty_questions/gamemaster/gpt-4-turbo-preview + max_questions: 20 + max_replies: 40 + +twenty_questions.shortlist.full: + class: evals.elsuite.twenty_questions.eval:TwentyQuestions + args: + samples_jsonl: twenty_questions/dataset.jsonl + gamemaster_spec: twenty_questions/gamemaster/gpt-4-turbo-preview + shortlist_variant: True + max_questions: 20 + max_replies: 40 + +twenty_questions.dev5: + class: evals.elsuite.twenty_questions.eval:TwentyQuestions + args: + samples_jsonl: twenty_questions/dataset.jsonl + gamemaster_spec: twenty_questions/gamemaster/gpt-4-turbo-preview + n_samples: 5 + max_questions: 20 + max_replies: 40 + +twenty_questions.shortlist.dev5: + class: evals.elsuite.twenty_questions.eval:TwentyQuestions + args: + samples_jsonl: twenty_questions/dataset.jsonl + gamemaster_spec: twenty_questions/gamemaster/gpt-4-turbo-preview + n_samples: 5 + shortlist_variant: True + num_shortlist_items: 5 + max_questions: 20 + max_replies: 40 + +twenty_questions.dev100: + class: evals.elsuite.twenty_questions.eval:TwentyQuestions + args: + samples_jsonl: twenty_questions/dataset.jsonl + gamemaster_spec: twenty_questions/gamemaster/gpt-4-turbo-preview + n_samples: 100 + max_questions: 20 + max_replies: 40 + 
+twenty_questions.shortlist.dev100: + class: evals.elsuite.twenty_questions.eval:TwentyQuestions + args: + samples_jsonl: twenty_questions/dataset.jsonl + gamemaster_spec: twenty_questions/gamemaster/gpt-4-turbo-preview + n_samples: 100 + shortlist_variant: True + max_questions: 20 + max_replies: 40 diff --git a/evals/registry/solvers/already_said_that.yaml b/evals/registry/solvers/already_said_that.yaml new file mode 100644 index 0000000000..71f0b65ff2 --- /dev/null +++ b/evals/registry/solvers/already_said_that.yaml @@ -0,0 +1,79 @@ +already_said_that/random_baseline: + class: evals.elsuite.already_said_that.solvers:RandomBaselineSolver + +already_said_that/human_cli: + class: evals.elsuite.already_said_that.solvers:AlreadySaidThatHuman + args: + human_cli_solver: + class: evals.solvers.human_cli_solver:HumanCliSolver + args: + registry: null + +already_said_that/cot/gpt-3.5-turbo: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + persistent_memory: False + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo + extra_options: + temperature: 1 + max_tokens: 512 + +already_said_that/cot/gpt-4-turbo-preview: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + persistent_memory: False + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 + max_tokens: 512 + +already_said_that/cot_hhh/gpt-4-base: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + persistent_memory: False + cot_solver: + class: 
evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 diff --git a/evals/registry/solvers/cant_do_that_anymore.yaml b/evals/registry/solvers/cant_do_that_anymore.yaml new file mode 100644 index 0000000000..951dd066bf --- /dev/null +++ b/evals/registry/solvers/cant_do_that_anymore.yaml @@ -0,0 +1,17 @@ +chess/generation/direct/gpt-3.5-turbo-instruct: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo-instruct + extra_options: + temperature: 1 + max_tokens: 4 + +chess/generation/direct/gpt-4-base: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 4 diff --git a/evals/registry/solvers/error_recovery.yaml b/evals/registry/solvers/error_recovery.yaml new file mode 100644 index 0000000000..bef801549e --- /dev/null +++ b/evals/registry/solvers/error_recovery.yaml @@ -0,0 +1,38 @@ +# TODO: use default solvers once they are versioned +error_recovery/gpt-3.5-turbo-0613: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo-0613 + +error_recovery/gpt-4-0613: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-0613 + +error_recovery/default/gpt-4-base: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + +# solver that continues the previous message 
+error_recovery/continue/gpt-4-base: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + continue_last_assistant_msg: True + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 diff --git a/evals/registry/solvers/function_deduction.yaml b/evals/registry/solvers/function_deduction.yaml new file mode 100644 index 0000000000..9b8837851b --- /dev/null +++ b/evals/registry/solvers/function_deduction.yaml @@ -0,0 +1,192 @@ +# OS CHAIN OF THOUGHT +function_deduction/cot/llama-2-13b-chat: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-13b-chat-hf + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-13b-chat-hf + extra_options: + temperature: 0 + max_tokens: 32 + +function_deduction/cot/llama-2-70b-chat: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-70b-chat-hf + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-70b-chat-hf + extra_options: + temperature: 0 + max_tokens: 32 + +function_deduction/cot/mixtral-8x7b-instruct: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: mistralai/Mixtral-8x7B-Instruct-v0.1 + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: 
mistralai/Mixtral-8x7B-Instruct-v0.1 + extra_options: + temperature: 0 + max_tokens: 32 + + +# CUSTOM CHAIN OF THOUGHT +function_deduction/cot/gpt-4-1106-preview: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-1106-preview + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-1106-preview + extra_options: + temperature: 0 + max_tokens: 32 + +function_deduction/cot/gpt-4-32k: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-32k + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-32k + extra_options: + temperature: 0 + max_tokens: 32 + +function_deduction/cot/gpt-3.5-turbo-16k: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo-16k + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo-16k + extra_options: + temperature: 0 + max_tokens: 32 + +function_deduction/cot/gemini-pro: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro + extract_solver: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro + + +# BASE MODELS +function_deduction/gpt-4-base: + class: evals.elsuite.function_deduction.solvers:BaseModelSolver + args: + solver: + class: 
evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 32 + +function_deduction/cot/gpt-4-base: + class: evals.elsuite.function_deduction.solvers:BaseModelCoTSolver + args: + cot_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + fixed_start: "Let's think step by step. " + extract_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 0 + max_tokens: 32 + + +# BASELINES +function_deduction/average_baseline: + class: evals.elsuite.function_deduction.baselines:AverageBaseline + +function_deduction/full_knowledge_random: + class: evals.elsuite.function_deduction.baselines:FullKnowledge + args: + mode: random + samples_jsonl: function_deduction/data.jsonl + +function_deduction/full_knowledge_best: + class: evals.elsuite.function_deduction.baselines:FullKnowledge + args: + mode: best + samples_jsonl: function_deduction/data.jsonl diff --git a/evals/registry/solvers/identifying_variables.yaml b/evals/registry/solvers/identifying_variables.yaml new file mode 100644 index 0000000000..aa6108febc --- /dev/null +++ b/evals/registry/solvers/identifying_variables.yaml @@ -0,0 +1,5 @@ +identifying_variables/random: + class: evals.elsuite.identifying_variables.solvers:RandomSolver + +identifying_variables/noctrl: + class: evals.elsuite.identifying_variables.solvers:NoCtrl diff --git a/evals/registry/solvers/incontext_rl.yaml b/evals/registry/solvers/incontext_rl.yaml new file mode 100644 index 0000000000..e374f2e75d --- /dev/null +++ b/evals/registry/solvers/incontext_rl.yaml @@ -0,0 +1,27 @@ +incontext_rl/random: + class: 
evals.elsuite.incontext_rl.baselines:RandomSolver + +incontext_rl/q-learning: + class: evals.elsuite.incontext_rl.baselines:QlearningSolver + +incontext_rl/anti-cot/gpt-3.5-turbo: + class: evals.elsuite.incontext_rl.anti-cot_solver:AntiCoTSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo + extra_options: + temperature: 1 + +incontext_rl/anti-cot/gpt-4-turbo-preview: + class: evals.elsuite.incontext_rl.anti-cot_solver:AntiCoTSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 \ No newline at end of file diff --git a/evals/registry/solvers/skill_acquisition.yaml b/evals/registry/solvers/skill_acquisition.yaml new file mode 100644 index 0000000000..e187837297 --- /dev/null +++ b/evals/registry/solvers/skill_acquisition.yaml @@ -0,0 +1,287 @@ +# CoT solvers with a custom extraction prompt. +skill_acquisition/cot/gpt-3.5-turbo: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo + extra_options: + temperature: 1 + max_tokens: 512 + extract_template: &extract_template Given the above reasoning, what is the next action you wish to take? Please respond in the format required by the instructions. 
+ extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo + extra_options: + temperature: 1 + max_tokens: 512 + +skill_acquisition/cot/gpt-4-turbo-preview: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 + max_tokens: 512 + extract_template: *extract_template + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 + max_tokens: 512 + +skill_acquisition/cot/gemini-pro: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro + extract_template: *extract_template + extract_solver: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro + +skill_acquisition/cot/gpt-4: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4 + extra_options: + temperature: 1 + max_tokens: 512 + extract_template: *extract_template + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4 + extra_options: + temperature: 1 + max_tokens: 512 + +skill_acquisition/cot_hhh/gpt-4-base: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + extract_template: *extract_template + extract_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: 
evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + +skill_acquisition/assistants/gpt-4-turbo-preview: + class: evals.elsuite.skill_acquisition.solvers:SkillAcquisitionAssistantsSolver + args: + tools: + - type: code_interpreter + - type: retrieval + model: gpt-4-turbo-preview + +skill_acquisition/cot_assistant/gpt-4-turbo-preview: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.elsuite.skill_acquisition.solvers:SkillAcquisitionAssistantsSolver + args: + tools: + - type: code_interpreter + - type: retrieval + model: gpt-4-turbo-preview + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 + max_tokens: 512 + +### Few-shot solvers. +# TODO: refactor few-shot solver so that train_jsonl is not parameterised here to reduce verbosity. +# Miskito full. 
+miskito_all/fewshot_direct/gpt-3.5-turbo: + class: evals.solvers.nested.fewshot_solver:FewShotSolver + args: + train_jsonl: evals/registry/data/skill_acquisition/miskito/variants/miskito_train_all.jsonl + n_shots: 3 + base_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo + extra_options: + temperature: 1 + max_tokens: 512 + +miskito_all/fewshot_direct/gpt-4-turbo-preview: + class: evals.solvers.nested.fewshot_solver:FewShotSolver + args: + train_jsonl: evals/registry/data/skill_acquisition/miskito/variants/miskito_train_all.jsonl + n_shots: 3 + base_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 + max_tokens: 512 + +miskito_all/fewshot_direct/gpt-4-32k: + class: evals.solvers.nested.fewshot_solver:FewShotSolver + args: + train_jsonl: evals/registry/data/skill_acquisition/miskito/variants/miskito_train_all.jsonl + n_shots: 3 + base_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-32k + extra_options: + temperature: 1 + max_tokens: 512 + +miskito_all/fewshot_direct/gpt-4-base: + class: evals.solvers.nested.fewshot_solver:FewShotSolver + args: + train_jsonl: evals/registry/data/skill_acquisition/miskito/variants/miskito_train_all.jsonl + n_shots: 3 + base_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + +miskito_manipulation/fewshot_direct/gpt-4-32k: + class: evals.solvers.nested.fewshot_solver:FewShotSolver + args: + train_jsonl: evals/registry/data/skill_acquisition/miskito/variants/miskito_train_manipulation.jsonl + n_shots: 3 + base_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-32k + extra_options: 
+ temperature: 1 + max_tokens: 512 + +miskito_manipulation/fewshot_direct/gpt-4-base: + class: evals.solvers.nested.fewshot_solver:FewShotSolver + args: + train_jsonl: evals/registry/data/skill_acquisition/miskito/variants/miskito_train_manipulation.jsonl + n_shots: 3 + base_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + +# OS models +skill_acquisition/cot/llama-2-13b-chat: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-13b-chat-hf + extra_options: + temperature: 1 + max_tokens: 512 + extract_template: *extract_template + extract_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-13b-chat-hf + extra_options: + temperature: 1 + max_tokens: 512 + +skill_acquisition/cot/llama-2-70b-chat: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-70b-chat-hf + extra_options: + temperature: 1 + max_tokens: 512 + extract_template: *extract_template + extract_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-70b-chat-hf + extra_options: + temperature: 1 + max_tokens: 512 + +skill_acquisition/cot/mixtral-8x7b-instruct: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: mistralai/Mixtral-8x7B-Instruct-v0.1 + extra_options: + temperature: 1 + max_tokens: 512 + extract_template: *extract_template + extract_solver: + class: evals.solvers.together_solver:TogetherSolver + 
args: + completion_fn_options: + model: mistralai/Mixtral-8x7B-Instruct-v0.1 + extra_options: + temperature: 1 + max_tokens: 512 diff --git a/evals/registry/solvers/track_the_stat.yaml b/evals/registry/solvers/track_the_stat.yaml new file mode 100644 index 0000000000..fc061f9583 --- /dev/null +++ b/evals/registry/solvers/track_the_stat.yaml @@ -0,0 +1,82 @@ +track_the_stat/explicit_state/gemini-pro: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro + state_role: "user" + +track_the_stat/explicit_state/llama-2-70b-chat: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-70b-chat-hf + extra_options: + temperature: 1 + max_tokens: 512 + +track_the_stat/explicit_state/mixtral-8x7b-instruct: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: mistralai/Mixtral-8x7B-Instruct-v0.1 + extra_options: + temperature: 1 + max_tokens: 512 + +track_the_stat/explicit_state/gpt-3.5-turbo: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo + extra_options: + temperature: 1 + max_tokens: 512 + +track_the_stat/explicit_state/gpt-4-turbo-preview: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 + max_tokens: 512 + +track_the_stat/explicit_state/hhh/gpt-4-base: + class: 
evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + +track_the_stat/human_cli: + class: evals.elsuite.track_the_stat.solvers:TrackTheStatHuman + args: + human_cli_solver: + class: evals.solvers.human_cli_solver:HumanCliSolver + args: + registry: null + +track_the_stat/random_baseline: + class: evals.elsuite.track_the_stat.solvers:RandomBaselineSolver diff --git a/evals/registry/solvers/twenty_questions.yaml b/evals/registry/solvers/twenty_questions.yaml new file mode 100644 index 0000000000..81cc65468c --- /dev/null +++ b/evals/registry/solvers/twenty_questions.yaml @@ -0,0 +1,80 @@ +# CoT solvers with a custom extract template. +twenty_questions/cot/gpt-3.5-turbo: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo + extra_options: + temperature: 1 + max_tokens: 512 + extract_template: &extract_template Given the above reasoning, ask a question or make a guess following the task instructions. 
+ extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo + extra_options: + temperature: 1 + max_tokens: 512 + +twenty_questions/cot/gpt-4-turbo-preview: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 + max_tokens: 512 + extract_template: *extract_template + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 + max_tokens: 512 + +twenty_questions/cot_hhh/gpt-4-base: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + extract_template: *extract_template + extract_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + +# Game-master uses a fixed solver, currently set to the latest-generation model. 
+twenty_questions/gamemaster/gpt-4-turbo-preview: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 0 + max_tokens: 1 + valid_answers: ["yes", "no", "skip"] \ No newline at end of file diff --git a/evals/utils/log_utils.py b/evals/utils/log_utils.py index d54a846f41..6ef2b5e8ff 100644 --- a/evals/utils/log_utils.py +++ b/evals/utils/log_utils.py @@ -14,6 +14,17 @@ def get_final_results_from_dir(log_dir: Union[str, Path]) -> dict[Path, dict]: return final_results_dict +def get_specs_from_dir(log_dir: Union[str, Path]) -> dict[Path, dict]: + """ + Given a directory of log files, return a dictionary mapping log file paths to specs. + """ + specs_dict = {} + for path in Path(log_dir).glob("**/*.log"): + spec = extract_spec(path) + specs_dict[path] = spec + return specs_dict + + def extract_final_results(path: Path) -> dict: """ Given a path to a log file, find and return the "final_report" dictionary. @@ -31,7 +42,7 @@ def extract_final_results(path: Path) -> dict: raise ValueError(f"Could not find final_report in {path}") -def extract_individual_results(path: Path) -> list[dict]: +def extract_individual_results(path: Path, type_string: str = "metrics") -> list[dict]: """ Given a path to a log file, grab all the individual sample results. """ @@ -42,7 +53,7 @@ def extract_individual_results(path: Path) -> list[dict]: try: loaded_line = json.loads(line) if "type" in loaded_line: - if loaded_line["type"] == "metrics": + if loaded_line["type"] == type_string: all_data.append(loaded_line["data"]) except json.decoder.JSONDecodeError: print(f"Skipping line: {line}") diff --git a/pyproject.toml b/pyproject.toml index e21afbf479..557d23a946 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,9 @@ dependencies = [ "aiolimiter", "beartype==0.12.0", "flask", + "gymnasium", + "networkx", + "chess", ] [project.urls]
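The `type_string` parameter added to `extract_individual_results` in the `evals/utils/log_utils.py` hunk lets callers pull events other than `"metrics"` out of a log file. Below is a minimal, self-contained sketch of that filtering logic, not the repository's implementation: the helper name `extract_events`, the demo log contents, and the `"sampling"` event type are assumptions made for illustration.

```python
import json
import tempfile
from pathlib import Path


def extract_events(path: Path, type_string: str = "metrics") -> list[dict]:
    """Hypothetical stand-in for extract_individual_results:
    collect the "data" payload of every log line whose "type" matches."""
    all_data = []
    with path.open() as f:
        for line in f:
            try:
                loaded_line = json.loads(line)
            except json.decoder.JSONDecodeError:
                # Mirror the diff's behaviour: report and skip unparsable lines.
                print(f"Skipping line: {line}")
                continue
            if "type" in loaded_line and loaded_line["type"] == type_string:
                all_data.append(loaded_line["data"])
    return all_data


# Build a tiny demo log (assumed layout: one JSON object per line).
demo_lines = [
    json.dumps({"spec": {"eval_name": "demo"}}),  # no "type" key, ignored
    json.dumps({"type": "metrics", "data": {"accuracy": 1.0}}),
    "not json",  # skipped with a message
    json.dumps({"type": "sampling", "data": {"prompt": "hi"}}),
]

with tempfile.TemporaryDirectory() as tmp:
    log_path = Path(tmp) / "demo.log"
    log_path.write_text("\n".join(demo_lines) + "\n")
    metrics = extract_events(log_path)  # default behaviour
    samplings = extract_events(log_path, type_string="sampling")

print(metrics)
print(samplings)
```

Keeping `"metrics"` as the default means existing callers are unaffected, while new callers can request another event type without a second parsing function.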