From b4e7c9953e0da22a6c7dc6b22174305e466388f9 Mon Sep 17 00:00:00 2001 From: Giulio Starace Date: Tue, 19 Mar 2024 11:11:49 +0100 Subject: [PATCH] open source anthropic solver --- evals/registry/solvers/anthropic.yaml | 125 +++++++++++++++ .../providers/anthropic/anthropic_solver.py | 148 ++++++++++++++++++ .../anthropic/anthropic_solver_test.py | 131 ++++++++++++++++ pyproject.toml | 1 + 4 files changed, 405 insertions(+) create mode 100644 evals/registry/solvers/anthropic.yaml create mode 100644 evals/solvers/providers/anthropic/anthropic_solver.py create mode 100644 evals/solvers/providers/anthropic/anthropic_solver_test.py diff --git a/evals/registry/solvers/anthropic.yaml b/evals/registry/solvers/anthropic.yaml new file mode 100644 index 0000000000..325f4c1cc2 --- /dev/null +++ b/evals/registry/solvers/anthropic.yaml @@ -0,0 +1,125 @@ +# ------------------ +# claude-3-opus-20240229 +# ------------------ + +generation/direct/claude-3-opus-20240229: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-3-opus-20240229 + +generation/cot/claude-3-opus-20240229: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-3-opus-20240229 + extract_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-3-opus-20240229 + +# ------------------ +# claude-3-sonnet-20240229 +# ------------------ + +generation/direct/claude-3-sonnet-20240229: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-3-sonnet-20240229 + +generation/cot/claude-3-sonnet-20240229: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-3-sonnet-20240229 + extract_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-3-sonnet-20240229 + +# ------------------ +# claude-3-haiku-20240307 +# ------------------ + +generation/direct/claude-3-haiku-20240307: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-3-haiku-20240307 + +generation/cot/claude-3-haiku-20240307: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-3-haiku-20240307 + extract_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-3-haiku-20240307 + +# ------------------ +# claude-2.1 +# ------------------ + +generation/direct/claude-2.1: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-2.1 + +generation/cot/claude-2.1: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-2.1 + extract_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-2.1 + +# ------------------ +# claude-2.0 +# ------------------ + +generation/direct/claude-2.0: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-2.0 + +generation/cot/claude-2.0: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-2.0 + extract_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-2.0 + +# ------------------ +# claude-instant-1.2 +# ------------------ + +generation/direct/claude-instant-1.2: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-instant-1.2 + +generation/cot/claude-instant-1.2: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-instant-1.2 + extract_solver: + class: evals.solvers.providers.anthropic.anthropic_solver:AnthropicSolver + args: + model_name: claude-instant-1.2 diff --git a/evals/solvers/providers/anthropic/anthropic_solver.py b/evals/solvers/providers/anthropic/anthropic_solver.py new file mode 100644 index 0000000000..9f0766598d --- /dev/null +++ b/evals/solvers/providers/anthropic/anthropic_solver.py @@ -0,0 +1,148 @@ +from typing import Any, Optional, Union + +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import TaskState, Message +from evals.record import record_sampling +from evals.utils.api_utils import request_with_timeout + +import anthropic +from anthropic import Anthropic +from anthropic.types import ContentBlock, MessageParam, Usage +import backoff + +oai_to_anthropic_role = { + "system": "user", + "user": "user", + "assistant": "assistant", +} + + +class AnthropicSolver(Solver): + """ + A solver class that uses the Anthropic API for textual chat-based tasks. + """ + + def __init__( + self, + model_name: str, + max_tokens: int = 512, + postprocessors: list[str] = [], + extra_options: Optional[dict] = {}, + registry: Any = None, + ): + super().__init__(postprocessors=postprocessors) + # https://docs.anthropic.com/claude/docs/models-overview#model-comparison + self.model_name = model_name + self.max_tokens = max_tokens + self.extra_options = extra_options + + def _solve(self, task_state: TaskState, **kwargs) -> SolverResult: + """ + Solve the task using the Anthropic API + """ + orig_msgs = task_state.messages + anth_msgs = self._convert_msgs_to_anthropic_format(task_state.messages) + + # TODO: handle context length limit; possible once anthropic tokenizer is available + + # calls client.messages.create, but is wrapped with backoff retrying decorator + response = anthropic_create_retrying( + client=Anthropic(max_retries=0), # we take care of retries ourselves + model=self.model_name, + system=task_state.task_description, + messages=anth_msgs, + max_tokens=self.max_tokens, # required kwarg for messages.create + **{**kwargs, **self.extra_options}, + ) + solver_result = SolverResult( + output=response.content[0].text, raw_completion_result=response.content + ) + + # for logging purposes: prepend the task desc to the orig msgs as a system message + orig_msgs.insert( + 0, Message(role="system", content=task_state.task_description).to_dict() + ) + record_sampling( + prompt=orig_msgs, # original message format, supported by our logviz + sampled=[solver_result.output], + model=self.model_name, + usage=anth_to_openai_usage(response.usage), + ) + return solver_result + + @property + def name(self) -> str: + return self.model_name + + @property + def model_version(self) -> Union[str, dict]: + """ + For the moment, Anthropic does not use aliases, + so model_version is the same as model_name. + """ + return self.model_name + + @staticmethod + def _convert_msgs_to_anthropic_format(msgs: list[Message]) -> list[MessageParam]: + """ + Anthropic API requires that the message list has + - Roles as 'user' or 'assistant' + - Alternating 'user' and 'assistant' messages + + Note: the top-level system prompt is handled separately and should not be + included in the messages list. + """ + # enforce valid roles; convert to Anthropic message type + anth_msgs = [ + MessageParam( + role=oai_to_anthropic_role[msg.role], + content=[ContentBlock(text=msg.content, type="text")], + ) + for msg in msgs + ] + # enforce alternating roles by merging consecutive messages with the same role + # e.g. [user1, user2, assistant1, user3] -> [user12, assistant1, user3] + alt_msgs = [] + for msg in anth_msgs: + if len(alt_msgs) > 0 and msg["role"] == alt_msgs[-1]["role"]: + # Merge consecutive messages from the same role + alt_msgs[-1]["content"].extend(msg["content"]) + else: + alt_msgs.append(msg) + + return alt_msgs + + +@backoff.on_exception( + wait_gen=backoff.expo, + exception=( + anthropic.RateLimitError, + anthropic.APIConnectionError, + anthropic.APITimeoutError, + anthropic.InternalServerError, + ), + max_value=60, + factor=1.5, +) +def anthropic_create_retrying(client: Anthropic, *args, **kwargs): + """ + Helper function for creating a backoff-retry enabled message request. + `args` and `kwargs` match what is accepted by `client.messages.create`. + """ + result = request_with_timeout(client.messages.create, *args, **kwargs) + if "error" in result: + raise Exception(result["error"]) + return result + + +def anth_to_openai_usage(anth_usage: Usage) -> dict: + """ + Processes anthropic Usage object into dict with keys + that match the OpenAI Usage dict, for logging purposes. + """ + # TODO: make this format of dict a dataclass type to be reused througout lib? + return { + "completion_tokens": anth_usage.output_tokens, + "prompt_tokens": anth_usage.input_tokens, + "total_tokens": anth_usage.input_tokens + anth_usage.output_tokens, + } diff --git a/evals/solvers/providers/anthropic/anthropic_solver_test.py b/evals/solvers/providers/anthropic/anthropic_solver_test.py new file mode 100644 index 0000000000..9ba8fb1470 --- /dev/null +++ b/evals/solvers/providers/anthropic/anthropic_solver_test.py @@ -0,0 +1,131 @@ +import os +import pytest + +from evals.record import DummyRecorder +from evals.task_state import Message, TaskState +from evals.solvers.providers.anthropic.anthropic_solver import ( + AnthropicSolver, + anth_to_openai_usage, +) + +from anthropic.types import ContentBlock, MessageParam, Usage + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" +MODEL_NAME = "claude-instant-1.2" + + +@pytest.fixture +def anthropic_solver(): + solver = AnthropicSolver( + model_name=MODEL_NAME, + ) + return solver + + +@pytest.fixture +def dummy_recorder(): + """ + Sets the "default_recorder" necessary for sampling in the solver. + """ + recorder = DummyRecorder(None) # type: ignore + with recorder.as_default_recorder("x"): + yield recorder + + +@pytest.mark.skipif( + IN_GITHUB_ACTIONS, reason="API tests are wasteful to run on every commit." +) +def test_solver(dummy_recorder, anthropic_solver): + """ + Test that the solver generates a response coherent with the message history + while following the instructions from the task description. + - checks the task description is understood + - checks that the messages are understood + """ + solver = anthropic_solver + + answer = "John Doe" + task_state = TaskState( + task_description=f"When you are asked for your name, respond with '{answer}' (without quotes).", + messages=[ + Message(role="user", content="What is 2 + 2?"), + Message(role="assistant", content="4"), + Message(role="user", content="What is your name?"), + ], + ) + + solver_res = solver(task_state=task_state) + assert ( + solver_res.output == answer + ), f"Expected '{answer}', but got {solver_res.output}" + + +def test_message_format(): + """ + Test that messages in our evals format are correctly + converted to the format expected by Anthropic + - "system" messages mapped to "user" in Anthropic + - messages must alternate between "user" and "assistant" + - messages are in MessageParam format + """ + msgs = [ + Message(role="user", content="What is 2 + 2?"), + Message(role="system", content="reason step by step"), + Message( + role="assistant", content="I don't need to reason for this, 2+2 is just 4" + ), + Message(role="system", content="now, given your reasoning, provide the answer"), + ] + anth_msgs = AnthropicSolver._convert_msgs_to_anthropic_format(msgs) + + expected = [ + MessageParam( + role="user", + content=[ + ContentBlock(text="What is 2 + 2?", type="text"), + ContentBlock(text="reason step by step", type="text"), + ], + ), + MessageParam( + role="assistant", + content=[ + ContentBlock( + text="I don't need to reason for this, 2+2 is just 4", type="text" + ), + ], + ), + MessageParam( + role="user", + content=[ + ContentBlock( + text="now, given your reasoning, provide the answer", type="text" + ), + ], + ), + ] + + assert anth_msgs == expected, f"Expected {expected}, but got {anth_msgs}" + + +def test_anth_to_openai_usage_correctness(): + usage = Usage(input_tokens=100, output_tokens=150) + expected = { + "completion_tokens": 150, + "prompt_tokens": 100, + "total_tokens": 250, + } + assert ( + anth_to_openai_usage(usage) == expected + ), "The conversion does not match the expected format." + + +def test_anth_to_openai_usage_zero_tokens(): + usage = Usage(input_tokens=0, output_tokens=0) + expected = { + "completion_tokens": 0, + "prompt_tokens": 0, + "total_tokens": 0, + } + assert ( + anth_to_openai_usage(usage) == expected + ), "Zero token cases are not handled correctly." diff --git a/pyproject.toml b/pyproject.toml index 4c4e6cbfa9..dabe5466c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "seaborn", "statsmodels", "chess", + "anthropic" ] [project.urls]