Skip to content

Commit

Permalink
Unified create_retrying for all solvers (#1501)
Browse files Browse the repository at this point in the history
We're now implementing solvers for new APIs we're calling (Anthropic,
Gemini, ...). Each solver was implementing the same logic for backing
off and retrying when the API query limit was hit. This PR created a
generic create_retrying function, which retries when specific exceptions
are raised. These exceptions are passed as arguments.

This uses the changes from #1482
  • Loading branch information
ojaffe authored Mar 26, 2024
1 parent ac44aae commit 150dcb9
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 79 deletions.
40 changes: 37 additions & 3 deletions evals/completion_fns/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from typing import Any, Optional, Union

import openai
from openai import OpenAI

from evals.api import CompletionFn, CompletionResult
Expand All @@ -12,12 +14,44 @@
Prompt,
)
from evals.record import record_sampling
from evals.utils.api_utils import (
openai_chat_completion_create_retrying,
openai_completion_create_retrying,
from evals.utils.api_utils import create_retrying

OPENAI_TIMEOUT_EXCEPTIONS = (
openai.RateLimitError,
openai.APIConnectionError,
openai.APITimeoutError,
openai.InternalServerError,
)


def openai_completion_create_retrying(client: OpenAI, *args, **kwargs):
"""
Helper function for creating a completion.
`args` and `kwargs` match what is accepted by `openai.Completion.create`.
"""
result = create_retrying(
client.completions.create, retry_exceptions=OPENAI_TIMEOUT_EXCEPTIONS, *args, **kwargs
)
if "error" in result:
logging.warning(result)
raise openai.APIError(result["error"])
return result


def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs):
"""
Helper function for creating a completion.
`args` and `kwargs` match what is accepted by `openai.Completion.create`.
"""
result = create_retrying(
client.chat.completions.create, retry_exceptions=OPENAI_TIMEOUT_EXCEPTIONS, *args, **kwargs
)
if "error" in result:
logging.warning(result)
raise openai.APIError(result["error"])
return result


class OpenAIBaseCompletionResult(CompletionResult):
def __init__(self, raw_data: Any, prompt: Any):
self.raw_data = raw_data
Expand Down
36 changes: 15 additions & 21 deletions evals/solvers/providers/anthropic/anthropic_solver.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
from typing import Any, Optional, Union

from evals.solvers.solver import Solver, SolverResult
from evals.task_state import TaskState, Message
from evals.record import record_sampling
from evals.utils.api_utils import request_with_timeout

import anthropic
from anthropic import Anthropic
from anthropic.types import ContentBlock, MessageParam, Usage
import backoff

from evals.record import record_sampling
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import Message, TaskState
from evals.utils.api_utils import create_retrying

oai_to_anthropic_role = {
"system": "user",
"user": "user",
"assistant": "assistant",
}
ANTHROPIC_TIMEOUT_EXCEPTIONS = (
anthropic.RateLimitError,
anthropic.APIConnectionError,
anthropic.APITimeoutError,
anthropic.InternalServerError,
)


class AnthropicSolver(Solver):
Expand Down Expand Up @@ -59,9 +64,7 @@ def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:
)

# for logging purposes: prepend the task desc to the orig msgs as a system message
orig_msgs.insert(
0, Message(role="system", content=task_state.task_description).to_dict()
)
orig_msgs.insert(0, Message(role="system", content=task_state.task_description).to_dict())
record_sampling(
prompt=orig_msgs, # original message format, supported by our logviz
sampled=[solver_result.output],
Expand Down Expand Up @@ -113,23 +116,14 @@ def _convert_msgs_to_anthropic_format(msgs: list[Message]) -> list[MessageParam]
return alt_msgs


@backoff.on_exception(
wait_gen=backoff.expo,
exception=(
anthropic.RateLimitError,
anthropic.APIConnectionError,
anthropic.APITimeoutError,
anthropic.InternalServerError,
),
max_value=60,
factor=1.5,
)
def anthropic_create_retrying(client: Anthropic, *args, **kwargs):
"""
Helper function for creating a backoff-retry enabled message request.
`args` and `kwargs` match what is accepted by `client.messages.create`.
"""
result = request_with_timeout(client.messages.create, *args, **kwargs)
result = create_retrying(
client.messages.create, retry_exceptions=ANTHROPIC_TIMEOUT_EXCEPTIONS, *args, **kwargs
)
if "error" in result:
raise Exception(result["error"])
return result
Expand Down
62 changes: 7 additions & 55 deletions evals/utils/api_utils.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,22 @@
"""
This file defines various helper functions for interacting with the OpenAI API.
"""
import logging
import os

import backoff
import openai
from openai import OpenAI

EVALS_THREAD_TIMEOUT = float(os.environ.get("EVALS_THREAD_TIMEOUT", "40"))
logging.getLogger("httpx").setLevel(logging.WARNING) # suppress "OK" logs from openai API calls


@backoff.on_exception(
@backoff.on_predicate(
wait_gen=backoff.expo,
exception=(
openai.RateLimitError,
openai.APIConnectionError,
openai.APITimeoutError,
openai.InternalServerError,
),
max_value=60,
factor=1.5,
)
def openai_completion_create_retrying(client: OpenAI, *args, **kwargs):
def create_retrying(func: callable, retry_exceptions: tuple[Exception], *args, **kwargs):
"""
Helper function for creating a completion.
`args` and `kwargs` match what is accepted by `openai.Completion.create`.
Retries given function if one of given exceptions is raised
"""
result = client.completions.create(*args, **kwargs)
if "error" in result:
logging.warning(result)
raise openai.error.APIError(result["error"])
return result


def request_with_timeout(func, *args, timeout=EVALS_THREAD_TIMEOUT, **kwargs):
"""
Function for making a single request within allotted time.
"""
while True:
try:
result = func(*args, timeout=timeout, **kwargs)
return result
except openai.APITimeoutError as e:
continue


@backoff.on_exception(
wait_gen=backoff.expo,
exception=(
openai.RateLimitError,
openai.APIConnectionError,
openai.APITimeoutError,
openai.InternalServerError,
),
max_value=60,
factor=1.5,
)
def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs):
"""
Helper function for creating a chat completion.
`args` and `kwargs` match what is accepted by `openai.ChatCompletion.create`.
"""
result = request_with_timeout(client.chat.completions.create, *args, **kwargs)
if "error" in result:
logging.warning(result)
raise openai.error.APIError(result["error"])
return result
try:
return func(*args, **kwargs)
except retry_exceptions:
return False

0 comments on commit 150dcb9

Please sign in to comment.