From 5a92ac38155cb32dcde1cc8b69b5e002e9437532 Mon Sep 17 00:00:00 2001 From: Oliver Jaffe Date: Tue, 26 Mar 2024 15:27:13 +0000 Subject: [PATCH] Add Gemini Solver (#1503) Adds a solver for Gemini 1.5 Pro. Stacked on #1501 and #1482. Using the solver requires the `GEMINI_API_KEY` environment variable Test with: ``` oaieval generation/direct/gemini-pro bugged_tools ``` --------- Co-authored-by: Chan Jun Shern --- evals/registry/solvers/gemini.yaml | 23 ++ .../solvers/providers/google/gemini_solver.py | 211 ++++++++++++++++++ .../providers/google/gemini_solver_test.py | 71 ++++++ .../solvers/providers/google/requirements.txt | 1 + pyproject.toml | 3 +- 5 files changed, 308 insertions(+), 1 deletion(-) create mode 100644 evals/registry/solvers/gemini.yaml create mode 100644 evals/solvers/providers/google/gemini_solver.py create mode 100644 evals/solvers/providers/google/gemini_solver_test.py create mode 100644 evals/solvers/providers/google/requirements.txt diff --git a/evals/registry/solvers/gemini.yaml b/evals/registry/solvers/gemini.yaml new file mode 100644 index 0000000000..eb8831990e --- /dev/null +++ b/evals/registry/solvers/gemini.yaml @@ -0,0 +1,23 @@ + +# ------------------ +# gemini-pro +# ------------------ + +# generation tasks + +generation/direct/gemini-pro: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro + +generation/cot/gemini-pro: + class: evals.solvers.nested.cot_solver:CoTSolver + args: + cot_solver: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro + extract_solver: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro diff --git a/evals/solvers/providers/google/gemini_solver.py b/evals/solvers/providers/google/gemini_solver.py new file mode 100644 index 0000000000..33a8a93a25 --- /dev/null +++ b/evals/solvers/providers/google/gemini_solver.py @@ -0,0 +1,211 @@ +import copy +import os +from dataclasses import asdict, dataclass +from typing import Any, Dict, Union + +import google.api_core.exceptions +import google.generativeai as genai +from google.generativeai.client import get_default_generative_client + +from evals.record import record_sampling +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import Message, TaskState +from evals.utils.api_utils import create_retrying + +# Load API key from environment variable +API_KEY = os.environ.get("GEMINI_API_KEY") +genai.configure(api_key=API_KEY) + +SAFETY_SETTINGS = [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE", + }, +] +GEMINI_RETRY_EXCEPTIONS = ( + google.api_core.exceptions.RetryError, + google.api_core.exceptions.TooManyRequests, + google.api_core.exceptions.ResourceExhausted, +) + + +# TODO: Could we just use google's own types? +# e.g. google.generativeai.types.content_types.ContentType +@dataclass +class GoogleMessage: + role: str + parts: list[str] + + def to_dict(self): + return asdict(self) + + @staticmethod + def from_evals_message(msg: Message): + valid_roles = {"user", "model"} + to_google_role = { + "system": "user", # Google doesn't have a "system" role + "user": "user", + "assistant": "model", + } + gmsg = GoogleMessage( + role=to_google_role.get(msg.role, msg.role), + parts=[msg.content], + ) + assert gmsg.role in valid_roles, f"Invalid role: {gmsg.role}" + return gmsg + + +class GeminiSolver(Solver): + """ + A solver class that uses Google's Gemini API to generate responses. + """ + + def __init__( + self, + model_name: str, + generation_config: Dict[str, Any] = {}, + postprocessors: list[str] = [], + registry: Any = None, + ): + super().__init__(postprocessors=postprocessors) + + self.model_name = model_name + self.gen_config = genai.GenerationConfig(**generation_config) + + # We manually define the client. This is normally defined automatically when calling + # the API, but it isn't thread-safe, so we anticipate its creation here + self.glm_client = get_default_generative_client() + + @property + def model(self) -> str: + return self.model_name + + def _solve( + self, + task_state: TaskState, + **kwargs, + ) -> SolverResult: + msgs = [ + Message(role="user", content=task_state.task_description), + ] + task_state.messages + gmsgs = self._convert_msgs_to_google_format(msgs) + gmsgs = [msg.to_dict() for msg in gmsgs] + try: + glm_model = genai.GenerativeModel(model_name=self.model_name) + glm_model._client = self.glm_client + + gen_content_resp = create_retrying( + glm_model.generate_content, + retry_exceptions=GEMINI_RETRY_EXCEPTIONS, + **{ + "contents": gmsgs, + "generation_config": self.gen_config, + "safety_settings": SAFETY_SETTINGS, + }, + ) + if gen_content_resp.prompt_feedback.block_reason: + # Blocked by safety filters + solver_result = SolverResult( + str(gen_content_resp.prompt_feedback), + error=gen_content_resp.prompt_feedback, + ) + else: + # Get text response + solver_result = SolverResult( + gen_content_resp.text, + error=gen_content_resp.prompt_feedback, + ) + except (google.api_core.exceptions.GoogleAPIError,) as e: + solver_result = SolverResult( + e.message, + error=e, + ) + except ValueError as e: + # TODO: Why does this error ever occur and how can we handle it better? + # (See google/generativeai/types/generation_types.py for the triggers) + known_errors = [ + "The `response.text` quick accessor", + "The `response.parts` quick accessor", + ] + if any(err in str(e) for err in known_errors): + solver_result = SolverResult( + str(e), + error=e, + ) + else: + raise e + + record_sampling( + prompt=msgs, + sampled=[solver_result.output], + model=self.model, + ) + return solver_result + + @staticmethod + def _convert_msgs_to_google_format(msgs: list[Message]) -> list[GoogleMessage]: + """ + Gemini API requires that the message list has + - Roles as 'user' or 'model' + - Alternating 'user' and 'model' messages + - Ends with a 'user' message + """ + # Enforce valid roles + gmsgs = [] + for msg in msgs: + gmsg = GoogleMessage.from_evals_message(msg) + gmsgs.append(gmsg) + assert gmsg.role in {"user", "model"}, f"Invalid role: {gmsg.role}" + + # Enforce alternating messages + # e.g. [user1, user2, model1, user3] -> [user12, model1, user3] + std_msgs = [] + for msg in gmsgs: + if len(std_msgs) > 0 and msg.role == std_msgs[-1].role: + # Merge consecutive messages from the same role + std_msgs[-1].parts.extend(msg.parts) + # The API seems to expect a single-element list of strings (???) so we join the + # parts into a list containing a single string + std_msgs[-1].parts = ["\n".join(std_msgs[-1].parts)] + else: + # Proceed as normal + std_msgs.append(msg) + + # Enforce last message is from the user + assert std_msgs[-1].role == "user", "Last message must be from the user" + return std_msgs + + @property + def name(self) -> str: + return self.model + + @property + def model_version(self) -> Union[str, dict]: + return self.model + + def __deepcopy__(self, memo): + """ + Deepcopy everything except for self.glm_client, which is instead shared across all copies + """ + cls = self.__class__ + result = cls.__new__(cls) + + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k != "glm_client": + setattr(result, k, copy.deepcopy(v, memo)) + + result.glm_client = self.glm_client + return result diff --git a/evals/solvers/providers/google/gemini_solver_test.py b/evals/solvers/providers/google/gemini_solver_test.py new file mode 100644 index 0000000000..9586c5f8f2 --- /dev/null +++ b/evals/solvers/providers/google/gemini_solver_test.py @@ -0,0 +1,71 @@ +import os + +import pytest + +from evals.record import DummyRecorder +from evals.solvers.providers.google.gemini_solver import GeminiSolver, GoogleMessage +from evals.task_state import Message, TaskState + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" +MODEL_NAME = "gemini-pro" + + +@pytest.fixture +def dummy_recorder(): + recorder = DummyRecorder(None) # type: ignore + with recorder.as_default_recorder("x"): + yield recorder + + +@pytest.fixture +def gemini_solver(): + os.environ["EVALS_SEQUENTIAL"] = "1" # TODO: Remove after fixing threading issue + solver = GeminiSolver( + model_name=MODEL_NAME, + ) + return solver + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="API tests are wasteful to run on every commit.") +def test_solver(dummy_recorder, gemini_solver): + """ + Test that the solver generates a response coherent with the message history + while following the instructions from the task description. + """ + solver = gemini_solver + + answer = "John Doe" + task_state = TaskState( + task_description=f"When you are asked for your name, respond with '{answer}' (without quotes).", + messages=[ + Message(role="user", content="What is 2 + 2?"), + Message(role="assistant", content="4"), + Message(role="user", content="What is your name?"), + ], + ) + + solver_res = solver(task_state=task_state) + assert solver_res.output == answer, f"Expected '{answer}', but got {solver_res.output}" + + +def test_message_format(): + """ + Test that messages in our evals format is correctly converted to the format + expected by Gemini. + """ + + messages = [ + Message(role="system", content="You are a great mathematician."), + Message(role="user", content="What is 2 + 2?"), + Message(role="assistant", content="5"), + Message(role="user", content="That's incorrect. What is 2 + 2?"), + ] + + gmessages = GeminiSolver._convert_msgs_to_google_format(messages) + expected = [ + GoogleMessage(role="user", parts=["You are a great mathematician.\nWhat is 2 + 2?"]), + GoogleMessage(role="model", parts=["5"]), + GoogleMessage(role="user", parts=["That's incorrect. What is 2 + 2?"]), + ] + + assert gmessages == expected, f"Expected {expected}, but got {gmessages}" diff --git a/evals/solvers/providers/google/requirements.txt b/evals/solvers/providers/google/requirements.txt new file mode 100644 index 0000000000..7f052b42ec --- /dev/null +++ b/evals/solvers/providers/google/requirements.txt @@ -0,0 +1 @@ +google-generativeai \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b383b3496a..6b32152584 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,8 @@ dependencies = [ "gymnasium", "networkx", "chess", - "anthropic" + "anthropic", + "google-generativeai", ] [project.urls]