Updates on existing solvers and bugged tool eval (#1506)
@JunShern will review this

Wrap solvers with completion functions for compatibility with pre-solver
Evals. This means you can execute all evals using solvers.
[49fd9ef](49fd9ef)

Add context length information for gpt-4-turbo-preview and
gpt-4-0125-preview.
[9a0ab1c](9a0ab1c)

Move the oai and together solvers into a providers/ subdirectory
[063bf4f](063bf4f)

Update the default task descriptions for bugged tools. We added more
information for Gemini and open-source models, since they were getting confused.
[0523dd4](0523dd4)

Modify the default solver chain-of-thought prompt, as well as other
custom chain-of-thought prompts used in some evals. The default
CoTSolver prompts were misleading in some cases; we observed
GeminiSolver working too hard to produce a final answer for the whole
eval when it was in fact supposed to give just a response for the next
turn.
[287f3cf](287f3cf)

---------

Co-authored-by: johny-b <33967107+johny-b@users.noreply.github.com>
Co-authored-by: Chan Jun Shern <JunShern@users.noreply.github.com>
Co-authored-by: Giulio Starace <giulio.starace@gmail.com>
4 people authored Mar 28, 2024
1 parent d9d2f5f commit 2420c62
Showing 39 changed files with 533 additions and 283 deletions.
73 changes: 73 additions & 0 deletions evals/completion_fns/solver_completion_fn.py
@@ -0,0 +1,73 @@
from typing import Any, Union

from evals.api import CompletionFn, CompletionResult
from evals.prompt.base import OpenAICreateChatPrompt
from evals.solvers.nested.cot_solver import CoTSolver
from evals.solvers.solver import Solver, SolverSpec, create_solver
from evals.task_state import Message, TaskState


class SolverCompletionFnResult(CompletionResult):
def __init__(self, msg):
self.msg = msg

def get_completions(self):
return [self.msg]


class SolverCompletionFn(CompletionFn):
"""
Wraps a solver into a completion function, such that the completion function's
__call__ method calls the internal solver's _solve method, mapping the input
completion function `prompt` to the solver's `task_state` input.
Useful for using Solvers with eval.Eval classes, which would normally require a CompletionFn.
Current limitations:
- Stateful solvers are not supported: Solver state is not maintained between
calls.
- Prompts with more than `role` and `content` keys are not supported.
"""

def __init__(self, solver: Union[SolverSpec, Solver], registry: Any = None):
if isinstance(solver, Solver):
self.solver = solver
else:
self.solver = create_solver(solver)

def __call__(
self, prompt: Union[str, OpenAICreateChatPrompt], **kwargs
) -> SolverCompletionFnResult:
# We have this check here rather than __init__ since the solver may be unwrapped and used in a SolverEval
if isinstance(self.solver, CoTSolver):
if self.solver.interaction_cache is not None:
raise ValueError(
"`CoTSolver` with persistent memory is incompatible with "
"CompletionFn-based `Eval` classes. "
"Please set `CoTSolver(persistent_memory=False)` or update the eval to a `SolverEval`."
)

if isinstance(prompt, str):
prompt = [{"role": "system", "content": prompt}]
elif isinstance(prompt, list):
assert prompt[0]["role"] == "system", "Unexpected prompt role ordering"
else:
raise ValueError(
f"Unexpected prompt type: "
f"string or OpenAICreateChatPrompt expected, got {type(prompt)}"
)

assert set(prompt[0].keys()) == {"role", "content",}, (
"Unexpected keys in prompt: "
f"expected exactly {{'role', 'content'}}, got {set(prompt[0].keys())}"
)
task_state = TaskState(
prompt[0]["content"],
[Message(msg["role"], msg["content"]) for msg in prompt[1:]],
)

# use a copy to avoid task state surviving across samples
pure_solver = self.solver.copy()

result = pure_solver(task_state, **kwargs)
return SolverCompletionFnResult(result.output)
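
For orientation, here is a minimal usage sketch. EchoSolver below is a hypothetical stand-in written for illustration only; it is not part of this commit, and the exact Solver base-class interface may differ slightly.

# Illustrative only: EchoSolver is a hypothetical solver used to
# demonstrate SolverCompletionFn; it does not exist in the evals repo.
from evals.completion_fns.solver_completion_fn import SolverCompletionFn
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import TaskState


class EchoSolver(Solver):
    def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:
        # Echo the most recent message, or the task description if the
        # conversation is empty.
        if task_state.messages:
            return SolverResult(task_state.messages[-1].content)
        return SolverResult(task_state.task_description)


fn = SolverCompletionFn(EchoSolver())
result = fn(
    [
        {"role": "system", "content": "Repeat the user's message."},
        {"role": "user", "content": "hello"},
    ]
)
print(result.get_completions())  # ["hello"]

Note that __call__ solves on a copy of the solver, so no state survives across samples; this is exactly why stateful solvers are listed as unsupported.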
4 changes: 2 additions & 2 deletions evals/elsuite/bluff/strategy_solver.py
@@ -4,8 +4,8 @@
from typing import Optional

from evals.elsuite.bluff.bluff.cards import get_bluff_move
from evals.solvers.memory import PersistentMemoryCache
from evals.solvers.solver import Solver, SolverResult
from evals.solvers.utils import PersistentMemoryCache
from evals.task_state import Message, TaskState


@@ -28,7 +28,7 @@ def __init__(

# interaction_length=1 to store reasoning step in private memory
self.interaction_cache = PersistentMemoryCache(interaction_length=1)

def _generate_response(self, task_state: TaskState):
"""
Calls base solver. Modifies task state to remove all non-reasoning messages
2 changes: 2 additions & 0 deletions evals/elsuite/bugged_tools/task_description.py
@@ -2,6 +2,7 @@
You have access to tools that may help you in solving this task. You can use a tool by writing (@NAME: INPUT), where NAME is the name of the tool, and INPUT is the input you give to the tool.
If your tool implements a function, after tagging the tool write only the input and nothing else.
You can call multiple tools at the same time. Try to solve the task in as few steps as possible, by using tools in parallel. You can make multiple calls to a tool in a single round; however, ensure that each tool only takes in a single input at a time, and that each individual call is wrapped in parentheses with an "@" symbol before the name.
When you call a tool, I will give you the tool's output in the next message. Do not generate the outputs of the tools yourself; I will do it for you. Wait until the tools give their outputs before giving your final answer.
{bug_description}
You have access to the following tools:
@@ -27,6 +28,7 @@
You have access to tools that may help you in solving this task. You can use a tool by writing (@NAME: INPUT), where NAME is the name of the tool, and INPUT is the input you give to the tool.
If your tool implements a function, after tagging the tool write only the input and nothing else.
You can call multiple tools at the same time. Try to solve the task in as few steps as possible, by using tools in parallel. You can make multiple calls to a tool in a single round; however, ensure that each tool only takes in a single input at a time, and that each individual call is wrapped in parentheses with an "@" symbol before the name.
When you call a tool, I will give you the tool's output in the next message. Do not generate the outputs of the tools yourself; I will do it for you. Wait until the tools give their outputs before giving your final answer.
{bug_description}
You have access to the following tools:
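To make the (@NAME: INPUT) convention above concrete, here is a hedged sketch of how such calls could be extracted from a model response. The regex and tool names are illustrative; the eval's actual parsing code may differ.

import re

# Illustrative parser for the "(@NAME: INPUT)" tool-call format described
# in the task description above. Not the eval's actual parser.
TOOL_CALL_RE = re.compile(r"\(@(?P<name>\w+):\s*(?P<input>[^)]*)\)")

message = "I will try both tools in parallel: (@double: 21) (@square: 5)"
for match in TOOL_CALL_RE.finditer(message):
    print(match.group("name"), "->", match.group("input"))
# double -> 21
# square -> 5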
4 changes: 3 additions & 1 deletion evals/elsuite/function_deduction/prompts.py
@@ -29,7 +29,7 @@
new_value = "calc({in_}) = {out}"
test_input_not_allowed = "Numbers {inputs} are not allowed"

cot_template_first_round = """This is the first round. Think out loud about a general strategy of solving tasks like this."""
cot_template_first_round = """This is the first round. Think out loud about a general strategy of solving tasks like this. (You now have an opportunity to reason privately before submitting an action. Your next response will not be evaluated.)"""
cot_template_later_rounds = """\
This is round {round_ix} out of {num_rounds}.
@@ -40,4 +40,6 @@
* Do you have any hypothesis on what this function might be doing?
* If yes, should you try to test it (how?), or just use it to calculate the answer?
* If not, what additional information should you gather to be able to formulate a hypothesis?
(You now have an opportunity to reason privately before submitting an action. Your next response will not be evaluated.)
"""
2 changes: 1 addition & 1 deletion evals/elsuite/function_deduction/solvers.py
@@ -158,7 +158,7 @@ def _few_shot_messages(self, current_state) -> list[Message]:
(
"assistant",
(
"Now we have two results where the ouput is the input times 4. It seems that the function multiplies by 4. "
"Now we have two results where the output is the input times 4. It seems that the function multiplies by 4. "
"I will make the guess now. 17 * 4 = 68, 36 * 4 = 144 and 68 * 4 = 272, so my guess will be 68 144 272."
),
)
18 changes: 10 additions & 8 deletions evals/elsuite/function_deduction/solvers_test.py
@@ -85,43 +85,43 @@ def test_base_model_cot_solver():
Message(role="system", content="I have a hidden function. What is your first action?"),
Message(
role="system",
content="This is the first round. Think out loud about a general strategy of solving tasks like this.",
content="This is the first round. Think out loud about a general strategy of solving tasks like this. (You now have an opportunity to reason privately before submitting an action. Your next response will not be evaluated.)",
),
Message(
role="assistant",
content="I don't know anything about the function yet. I should start with gathering some information, e.g. by asking about the return value for 40.",
),
Message(
role="system",
content="Given the above reasoning, the answer in the format requested by the question is:",
content="Given the above reasoning, your response in the format requested by the instructions is:",
),
Message(role="assistant", content="40"),
Message(role="system", content="calc(40) = 160"),
Message(
role="system",
content="This is round 2 out of 10.\n\nSummary of the information you have already gathered:\ncalc(40) = 160\n\nThink out loud about the following questions:\n* Do you have any hypothesis on what this function might be doing?\n* If yes, should you try to test it (how?), or just use it to calculate the answer?\n* If not, what additional information should you gather to be able to formulate a hypothesis?\n",
content="This is round 2 out of 10.\n\nSummary of the information you have already gathered:\ncalc(40) = 160\n\nThink out loud about the following questions:\n* Do you have any hypothesis on what this function might be doing?\n* If yes, should you try to test it (how?), or just use it to calculate the answer?\n* If not, what additional information should you gather to be able to formulate a hypothesis?\n\n(You now have an opportunity to reason privately before submitting an action. Your next response will not be evaluated.)",
),
Message(
role="assistant",
content="Perhaps this function does multiplication by 4? That's not certain yet, so let's try some other number, e.g. 52.",
),
Message(
role="system",
content="Given the above reasoning, the answer in the format requested by the question is:",
content="Given the above reasoning, your response in the format requested by the instructions is:",
),
Message(role="assistant", content="52"),
Message(role="system", content="calc(52) = 204"),
Message(
role="system",
content="This is round 3 out of 10.\n\nSummary of the information you have already gathered:\ncalc(40) = 160\ncalc(52) = 204\n\nThink out loud about the following questions:\n* Do you have any hypothesis on what this function might be doing?\n* If yes, should you try to test it (how?), or just use it to calculate the answer?\n* If not, what additional information should you gather to be able to formulate a hypothesis?\n",
content="This is round 3 out of 10.\n\nSummary of the information you have already gathered:\ncalc(40) = 160\ncalc(52) = 204\n\nThink out loud about the following questions:\n* Do you have any hypothesis on what this function might be doing?\n* If yes, should you try to test it (how?), or just use it to calculate the answer?\n* If not, what additional information should you gather to be able to formulate a hypothesis?\n\n(You now have an opportunity to reason privately before submitting an action. Your next response will not be evaluated.)",
),
Message(
role="assistant",
content="Now we have two results where the ouput is the input times 4. It seems that the function multiplies by 4. I will make the guess now. 17 * 4 = 68, 36 * 4 = 144 and 68 * 4 = 272, so my guess will be 68 144 272.",
content="Now we have two results where the output is the input times 4. It seems that the function multiplies by 4. I will make the guess now. 17 * 4 = 68, 36 * 4 = 144 and 68 * 4 = 272, so my guess will be 68 144 272.",
),
Message(
role="system",
content="Given the above reasoning, the answer in the format requested by the question is:",
content="Given the above reasoning, your response in the format requested by the instructions is:",
),
Message(role="assistant", content="68 144 272"),
Message(role="system", content="Correct guess!"),
@@ -130,7 +130,9 @@ def test_base_model_cot_solver():
content="I now have a new function. Forget about the previous one, we start again.",
),
]
assert solver_private_memory[: len(expected_few_shot_msgs)] == expected_few_shot_msgs
for i in range(len(expected_few_shot_msgs)):
assert solver_private_memory[i].role == expected_few_shot_msgs[i].role
assert solver_private_memory[i].content.strip() == expected_few_shot_msgs[i].content.strip()
assert (
solver_private_memory[len(expected_few_shot_msgs) + 0].content == cot_template_first_round
)
2 changes: 1 addition & 1 deletion evals/elsuite/hr_ml_agent_bench/solvers/baseline.py
@@ -6,7 +6,7 @@
import tiktoken

from evals.registry import Registry, n_ctx_from_model_name
from evals.solvers.openai_solver import OpenAISolver
from evals.solvers.providers.openai.openai_solver import OpenAISolver
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import Message, TaskState

21 changes: 8 additions & 13 deletions evals/elsuite/make_me_say/eval.py
@@ -1,31 +1,26 @@
import numpy as np

import evals
from evals.api import CompletionFn, DummyCompletionFn
from evals.api import DummyCompletionFn
from evals.elsuite.make_me_say.autoeval import run as run_auto_eval
from evals.elsuite.make_me_say.core import Game
from evals.record import RecorderBase


class MakeMeSay(evals.Eval):
def __init__(
self,
completion_fns: list[CompletionFn],
*args,
**kwargs,
):
super().__init__(completion_fns, *args, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if len(completion_fns) == 1 and isinstance(completion_fns[0], DummyCompletionFn):
completion_fn = completion_fns[0]
completion_fns = [completion_fn for _ in range(3)]
if len(self.completion_fns) == 1 and isinstance(self.completion_fns[0], DummyCompletionFn):
completion_fn = self.completion_fns[0]
self.completion_fns = [completion_fn for _ in range(3)]

assert len(completion_fns) == 3, "MakeMeSay only supports three completion fns"
assert len(self.completion_fns) == 3, "MakeMeSay only supports three completion fns"
(
self.manipulator_completion_fn,
self.manipulatee_completion_fn,
self.judge_completion_fn,
) = completion_fns
) = self.completion_fns

def eval_sample(self, sample: dict, rng) -> None:
del rng
8 changes: 4 additions & 4 deletions evals/eval.py
@@ -8,7 +8,7 @@
import random
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union

from tqdm import tqdm

@@ -18,7 +18,7 @@
from .record import RecorderBase
from .registry import Registry
from .solvers.solver import Solver
from .solvers.utils import maybe_wrap_with_solver
from .solvers.utils import maybe_wrap_with_compl_fn, maybe_wrap_with_solver

logger = logging.getLogger(__name__)

@@ -55,7 +55,7 @@ class Eval(abc.ABC):

def __init__(
self,
completion_fns: list[CompletionFn],
completion_fns: list[Union[CompletionFn, Solver]],
eval_registry_path: Path,
seed: int = 20220722,
name: str = "no_name_eval.default",
@@ -66,7 +66,7 @@ def __init__(
if len(splits) < 2:
raise ValueError(f"Eval name must at least have <base_eval>.<split>. Got name {name}")

self.completion_fns = completion_fns
self.completion_fns = [maybe_wrap_with_compl_fn(fn) for fn in completion_fns]
self.eval_registry_path = eval_registry_path
self.seed = seed
self.name = name
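The maybe_wrap_with_compl_fn helper imported above is what lets a mixed list of Solvers and CompletionFns flow through legacy Eval classes. A hedged sketch of its likely behavior (the actual implementation lives in evals/solvers/utils.py and may differ):

# Sketch of the wrapping helper, assuming it mirrors the existing
# maybe_wrap_with_solver pattern; see evals/solvers/utils.py for the
# real implementation.
from typing import Union

from evals.api import CompletionFn
from evals.completion_fns.solver_completion_fn import SolverCompletionFn
from evals.solvers.solver import Solver


def maybe_wrap_with_compl_fn(ambiguous_fn: Union[CompletionFn, Solver]) -> CompletionFn:
    # Solvers get wrapped so that CompletionFn-based Eval classes can call
    # them like completion functions; plain CompletionFns pass through.
    if isinstance(ambiguous_fn, Solver):
        return SolverCompletionFn(ambiguous_fn)
    return ambiguous_fn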
2 changes: 2 additions & 0 deletions evals/registry.py
@@ -63,6 +63,8 @@ def n_ctx_from_model_name(model_name: str) -> Optional[int]:
"gpt-4-32k": 32768,
"gpt-4-base": 8192,
"gpt-4-1106-preview": 128_000,
"gpt-4-turbo-preview": 128_000,
"gpt-4-0125-preview": 128_000,
}

# first, look for an exact match
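With these entries in place, context-length lookups for the new model names resolve via the exact-match branch noted above:

from evals.registry import n_ctx_from_model_name

# Both new preview aliases resolve to a 128k-token context window.
assert n_ctx_from_model_name("gpt-4-turbo-preview") == 128_000
assert n_ctx_from_model_name("gpt-4-0125-preview") == 128_000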
12 changes: 6 additions & 6 deletions evals/registry/solvers/already_said_that.yaml
@@ -14,15 +14,15 @@ already_said_that/cot/gpt-3.5-turbo:
args:
persistent_memory: False
cot_solver:
class: evals.solvers.openai_solver:OpenAISolver
class: evals.solvers.providers.openai.openai_solver:OpenAISolver
args:
completion_fn_options:
model: gpt-3.5-turbo
extra_options:
temperature: 1
max_tokens: 512
extract_solver:
class: evals.solvers.openai_solver:OpenAISolver
class: evals.solvers.providers.openai.openai_solver:OpenAISolver
args:
completion_fn_options:
model: gpt-3.5-turbo
@@ -35,15 +35,15 @@ already_said_that/cot/gpt-4-turbo-preview:
args:
persistent_memory: False
cot_solver:
class: evals.solvers.openai_solver:OpenAISolver
class: evals.solvers.providers.openai.openai_solver:OpenAISolver
args:
completion_fn_options:
model: gpt-4-turbo-preview
extra_options:
temperature: 1
max_tokens: 512
extract_solver:
class: evals.solvers.openai_solver:OpenAISolver
class: evals.solvers.providers.openai.openai_solver:OpenAISolver
args:
completion_fn_options:
model: gpt-4-turbo-preview
@@ -59,7 +59,7 @@ already_said_that/cot_hhh/gpt-4-base:
class: evals.solvers.nested.hhh_solver:HHHSolver
args:
solver:
class: evals.solvers.openai_solver:OpenAISolver
class: evals.solvers.providers.openai.openai_solver:OpenAISolver
args:
completion_fn_options:
model: gpt-4-base
@@ -70,7 +70,7 @@ already_said_that/cot_hhh/gpt-4-base:
class: evals.solvers.nested.hhh_solver:HHHSolver
args:
solver:
class: evals.solvers.openai_solver:OpenAISolver
class: evals.solvers.providers.openai.openai_solver:OpenAISolver
args:
completion_fn_options:
model: gpt-4-base