Skip to content

Commit

Permalink
general create_retrying func for all solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
ojaffe committed Mar 21, 2024
1 parent e30e141 commit c08488e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 84 deletions.
44 changes: 39 additions & 5 deletions evals/completion_fns/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from typing import Any, Optional, Union

import openai
from openai import OpenAI

from evals.api import CompletionFn, CompletionResult
Expand All @@ -12,12 +14,44 @@
Prompt,
)
from evals.record import record_sampling
from evals.utils.api_utils import (
openai_chat_completion_create_retrying,
openai_completion_create_retrying,
from evals.utils.api_utils import create_retrying

OPENAI_TIMEOUT_EXCEPTIONS = (
openai.RateLimitError,
openai.APIConnectionError,
openai.APITimeoutError,
openai.InternalServerError,
)


def openai_completion_create_retrying(client: OpenAI, *args, **kwargs):
"""
Helper function for creating a completion.
`args` and `kwargs` match what is accepted by `openai.Completion.create`.
"""
result = create_retrying(
client.completions.create, retry_exceptions=OPENAI_TIMEOUT_EXCEPTIONS, *args, **kwargs
)
if "error" in result:
logging.warning(result)
raise openai.APIError(result["error"])
return result


def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs):
"""
Helper function for creating a completion.
`args` and `kwargs` match what is accepted by `openai.Completion.create`.
"""
result = create_retrying(
client.chat.completions.create, retry_exceptions=OPENAI_TIMEOUT_EXCEPTIONS, *args, **kwargs
)
if "error" in result:
logging.warning(result)
raise openai.APIError(result["error"])
return result


class OpenAIBaseCompletionResult(CompletionResult):
def __init__(self, raw_data: Any, prompt: Any):
self.raw_data = raw_data
Expand Down Expand Up @@ -82,7 +116,7 @@ def __call__(
openai_create_prompt: OpenAICreatePrompt = prompt.to_formatted_prompt()

result = openai_completion_create_retrying(
OpenAI(api_key=self.api_key, base_url=self.api_base),
client=OpenAI(api_key=self.api_key, base_url=self.api_base),
model=self.model,
prompt=openai_create_prompt,
**{**kwargs, **self.extra_options},
Expand Down Expand Up @@ -127,7 +161,7 @@ def __call__(
openai_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt()

result = openai_chat_completion_create_retrying(
OpenAI(api_key=self.api_key, base_url=self.api_base),
client=OpenAI(api_key=self.api_key, base_url=self.api_base),
model=self.model,
messages=openai_create_prompt,
**{**kwargs, **self.extra_options},
Expand Down
36 changes: 15 additions & 21 deletions evals/solvers/providers/anthropic/anthropic_solver.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
from typing import Any, Optional, Union

from evals.solvers.solver import Solver, SolverResult
from evals.task_state import TaskState, Message
from evals.record import record_sampling
from evals.utils.api_utils import request_with_timeout

import anthropic
from anthropic import Anthropic
from anthropic.types import ContentBlock, MessageParam, Usage
import backoff

from evals.record import record_sampling
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import Message, TaskState
from evals.utils.api_utils import create_retrying

oai_to_anthropic_role = {
"system": "user",
"user": "user",
"assistant": "assistant",
}
ANTHROPIC_TIMEOUT_EXCEPTIONS = (
anthropic.RateLimitError,
anthropic.APIConnectionError,
anthropic.APITimeoutError,
anthropic.InternalServerError,
)


class AnthropicSolver(Solver):
Expand Down Expand Up @@ -59,9 +64,7 @@ def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:
)

# for logging purposes: prepend the task desc to the orig msgs as a system message
orig_msgs.insert(
0, Message(role="system", content=task_state.task_description).to_dict()
)
orig_msgs.insert(0, Message(role="system", content=task_state.task_description).to_dict())
record_sampling(
prompt=orig_msgs, # original message format, supported by our logviz
sampled=[solver_result.output],
Expand Down Expand Up @@ -113,23 +116,14 @@ def _convert_msgs_to_anthropic_format(msgs: list[Message]) -> list[MessageParam]
return alt_msgs


@backoff.on_exception(
wait_gen=backoff.expo,
exception=(
anthropic.RateLimitError,
anthropic.APIConnectionError,
anthropic.APITimeoutError,
anthropic.InternalServerError,
),
max_value=60,
factor=1.5,
)
def anthropic_create_retrying(client: Anthropic, *args, **kwargs):
"""
Helper function for creating a backoff-retry enabled message request.
`args` and `kwargs` match what is accepted by `client.messages.create`.
"""
result = request_with_timeout(client.messages.create, *args, **kwargs)
result = create_retrying(
client.messages.create, retry_exceptions=ANTHROPIC_TIMEOUT_EXCEPTIONS, *args, **kwargs
)
if "error" in result:
raise Exception(result["error"])
return result
Expand Down
65 changes: 7 additions & 58 deletions evals/utils/api_utils.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,22 @@
"""
This file defines various helper functions for interacting with the OpenAI API.
"""
import concurrent
import logging
import os

import backoff
import openai
from openai import OpenAI

EVALS_THREAD_TIMEOUT = float(os.environ.get("EVALS_THREAD_TIMEOUT", "40"))
logging.getLogger("httpx").setLevel(logging.WARNING) # suppress "OK" logs from openai API calls


@backoff.on_exception(
@backoff.on_predicate(
wait_gen=backoff.expo,
exception=(
openai.RateLimitError,
openai.APIConnectionError,
openai.APITimeoutError,
openai.InternalServerError,
),
max_value=60,
factor=1.5,
)
def openai_completion_create_retrying(client: OpenAI, *args, **kwargs):
def create_retrying(func: callable, retry_exceptions: tuple[Exception], *args, **kwargs):
"""
Helper function for creating a completion.
`args` and `kwargs` match what is accepted by `openai.Completion.create`.
Retries given function if one of given exceptions is raised
"""
result = client.completions.create(*args, **kwargs)
if "error" in result:
logging.warning(result)
raise openai.error.APIError(result["error"])
return result


def request_with_timeout(func, *args, timeout=EVALS_THREAD_TIMEOUT, **kwargs):
"""
Worker thread for making a single request within allotted time.
"""
while True:
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(func, *args, **kwargs)
try:
result = future.result(timeout=timeout)
return result
except concurrent.futures.TimeoutError:
continue


@backoff.on_exception(
wait_gen=backoff.expo,
exception=(
openai.RateLimitError,
openai.APIConnectionError,
openai.APITimeoutError,
openai.InternalServerError,
),
max_value=60,
factor=1.5,
)
def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs):
"""
Helper function for creating a chat completion.
`args` and `kwargs` match what is accepted by `openai.ChatCompletion.create`.
"""
result = request_with_timeout(client.chat.completions.create, *args, **kwargs)
if "error" in result:
logging.warning(result)
raise openai.error.APIError(result["error"])
return result
try:
return func(*args, **kwargs)
except retry_exceptions:
return False

0 comments on commit c08488e

Please sign in to comment.