Merge pull request #2 from plasma-umass/bedrock_support
Bedrock support [WIP]
jaltmayerpizzorno authored Mar 23, 2024
2 parents fc0e0c5 + c26d957 commit 9c546ea
Showing 3 changed files with 61 additions and 21 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -18,12 +18,13 @@ classifiers = [
requires-python = ">=3.10"
dependencies = [
"asyncio",
"openai==0.28",
"openai",
"tiktoken",
"aiolimiter",
"tqdm",
"llm_utils",
"slipcover>=1.0.3"
"slipcover>=1.0.3",
"litellm>=1.33.1"
]

[project.scripts]
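The dependency change unpins the old openai==0.28 client and brings in litellm, which exposes one OpenAI-style completion API across providers. A minimal sketch of the provider-agnostic call this enables (the prompt is illustrative):

import litellm

# The same call shape serves OpenAI and Amazon Bedrock; only the
# "provider/model" string changes.
response = litellm.completion(
    model="openai/gpt-4-1106-preview",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.choices[0].message.content)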
71 changes: 53 additions & 18 deletions src/coverup/coverup.py
@@ -1,20 +1,28 @@
import asyncio
import openai
import json
import litellm # type: ignore
import logging
import openai
import subprocess
from pathlib import Path
import typing as T
import re
import sys
import typing as T

from pathlib import Path
from datetime import datetime

from .llm import *
from .segment import *
from .testrunner import *


PREFIX = 'coverup'
DEFAULT_MODEL='gpt-4-1106-preview'

# Turn off most logging
litellm.set_verbose = False
logging.getLogger().setLevel(logging.ERROR)
# Ignore unavailable parameters
litellm.drop_params=True

def parse_args(args=None):
import argparse
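The module-level settings added above turn off litellm's verbose logging and set drop_params, which makes litellm silently drop request parameters the selected provider does not support instead of raising an error. A small sketch of the effect, assuming litellm >= 1.33 and valid AWS credentials:

import litellm

litellm.drop_params = True  # drop unsupported parameters rather than erroring

# logit_bias is an OpenAI-specific parameter; with drop_params set,
# litellm strips it from the Bedrock request instead of failing.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "hello"}],
    logit_bias={42: 1.0},
)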
@@ -40,7 +48,7 @@ def Path_dir(value):
ap.add_argument('--no-checkpoint', action='store_const', const=None, dest='checkpoint', default=argparse.SUPPRESS,
help=f'disables checkpoint')

ap.add_argument('--model', type=str, default=DEFAULT_MODEL,
ap.add_argument('--model', type=str,
help='OpenAI model to use')

ap.add_argument('--model-temperature', type=str, default=0,
@@ -103,6 +111,7 @@ def positive_int(value):

def test_file_path(test_seq: int) -> Path:
"""Returns the Path for a test's file, given its sequence number."""
global args
return args.tests_dir / f"test_{PREFIX}_{test_seq}.py"


@@ -413,11 +422,9 @@ async def do_chat(seg: CodeSegment, completion: dict) -> str:
log_write(seg, f"Error: too many tokens for rate limit ({e})")
return None # gives up this segment

return await openai.ChatCompletion.acreate(**completion)
return await litellm.acreate(**completion)

except (openai.error.ServiceUnavailableError,
openai.error.RateLimitError,
openai.error.Timeout) as e:
except (openai.RateLimitError, openai.APITimeoutError) as e:

# This message usually indicates out of money in account
if 'You exceeded your current quota' in str(e):
@@ -432,13 +439,12 @@ async def do_chat(seg: CodeSegment, completion: dict) -> str:
state.inc_counter('R')
await asyncio.sleep(sleep_time)

except openai.error.InvalidRequestError as e:
except openai.BadRequestError as e:
# usually "maximum context length" XXX check for this?
log_write(seg, f"Error: {type(e)} {e}")
return None # gives up this segment

except (openai.error.APIConnectionError,
openai.error.APIError) as e:
except (ConnectionError) as e:
log_write(seg, f"Error: {type(e)} {e}")
# usually a server-side error... just retry right away
state.inc_counter('R')
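These two hunks track the openai-python 1.x renames: openai.error.RateLimitError is now openai.RateLimitError, openai.error.Timeout is openai.APITimeoutError, and openai.error.InvalidRequestError is openai.BadRequestError; litellm raises these OpenAI-style exception types for every provider, so one except chain covers both backends. A minimal sketch of the resulting retry shape (the helper name and backoff constants are illustrative, and litellm's documented async entry point is acompletion):

import asyncio
import litellm
import openai

async def complete_with_retry(completion: dict, sleep: float = 1.0):
    while True:
        try:
            return await litellm.acompletion(**completion)
        except (openai.RateLimitError, openai.APITimeoutError):
            await asyncio.sleep(sleep)   # transient: back off and retry
            sleep = min(sleep * 2, 64.0)
        except openai.BadRequestError:
            return None                  # permanent, e.g. context too long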
@@ -589,6 +595,7 @@ def add_to_pythonpath(source_dir: Path):


def main():

from collections import defaultdict
import os

@@ -612,14 +619,42 @@ def main():
token_rate_limit = AsyncLimiter(*limit)
# TODO also add request limit, and use 'await asyncio.gather(t.acquire(tokens), r.acquire())' to acquire both

if 'OPENAI_API_KEY' not in os.environ:
print("Please place your OpenAI key in an environment variable named OPENAI_API_KEY and try again.")
return 1

openai.key=os.environ['OPENAI_API_KEY']
if 'OPENAI_ORGANIZATION' in os.environ:
openai.organization=os.environ['OPENAI_ORGANIZATION']
# Check for an API key for OpenAI or Amazon Bedrock.
if 'OPENAI_API_KEY' not in os.environ:
if not all(x in os.environ for x in ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION_NAME']):
print("You need a key (or keys) from an AI service to use CoverUp.")
print()
print("OpenAI:")
print(" You can get a key here: https://platform.openai.com/api-keys")
print(" Set the environment variable OPENAI_API_KEY to your key value:")
print(" export OPENAI_API_KEY=<your key>")
print()
print()
print("Bedrock:")
print(" To use Bedrock, you need an AWS account.")
print(" Set the following environment variables:")
print(" export AWS_ACCESS_KEY_ID=<your key id>")
print(" export AWS_SECRET_ACCESS_KEY=<your secret key>")
print(" export AWS_REGION_NAME=us-west-2")
print(" You also need to request access to Claude:")
print(
" https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access"
)
print()
return 1

if 'OPENAI_API_KEY' in os.environ:
if not args.model:
# args.model = "openai/gpt-4"
args.model = "openai/gpt-4-1106-preview"
# openai.key=os.environ['OPENAI_API_KEY']
#if 'OPENAI_ORGANIZATION' in os.environ:
# openai.organization=os.environ['OPENAI_ORGANIZATION']
else:
# args.model = "bedrock/anthropic.claude-v2:1"
if not args.model:
args.model = "bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
log_write('startup', f"Command: {' '.join(sys.argv)}")

# --- (1) load or measure initial coverage, figure out segmentation ---
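In main() above, the provider prefix of the model string does the routing: with an OPENAI_API_KEY set, the default becomes openai/gpt-4-1106-preview; otherwise the AWS_* variables are required and the default is a bedrock/ Claude model. A condensed sketch of that selection (model strings taken from the diff):

import os

if 'OPENAI_API_KEY' in os.environ:
    model = "openai/gpt-4-1106-preview"   # litellm routes this to OpenAI
else:
    # litellm routes this to Amazon Bedrock, authenticating with
    # AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_REGION_NAME
    model = "bedrock/anthropic.claude-3-sonnet-20240229-v1:0"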
6 changes: 5 additions & 1 deletion src/coverup/utils.py
@@ -85,7 +85,11 @@ async def subprocess_run(args: str, check: bool = False, timeout: T.Optional[int
except asyncio.TimeoutError:
process.terminate()
await process.wait()
raise subprocess.TimeoutExpired(args, timeout) from None
if timeout:
timeout_f = float(timeout)
else:
timeout_f = 0.0
raise subprocess.TimeoutExpired(args, timeout_f) from None

if check and process.returncode != 0:
raise subprocess.CalledProcessError(process.returncode, args, output=output)
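The utils.py fix addresses a typing mismatch: subprocess_run takes timeout as T.Optional[int], while subprocess.TimeoutExpired expects a float, so the value is converted (with None mapped to 0.0) before re-raising. A condensed sketch of the surrounding pattern, assuming the same asyncio-based runner:

import asyncio
import subprocess
import typing as T

async def run_checked(args: str, timeout: T.Optional[int] = None) -> bytes:
    process = await asyncio.create_subprocess_shell(
        args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT)
    try:
        output, _ = await asyncio.wait_for(process.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        process.terminate()
        await process.wait()
        # TimeoutExpired's timeout parameter is a float; None becomes 0.0
        raise subprocess.TimeoutExpired(args, float(timeout) if timeout else 0.0) from None
    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, args, output=output)
    return output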
