Skip to content

Commit

Permalink
feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Adrian Cole <adrian.cole@elastic.co>
  • Loading branch information
codefromthecrypt committed Nov 4, 2024
1 parent f53c503 commit 224d729
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 118 deletions.
9 changes: 4 additions & 5 deletions packages/exchange/src/exchange/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
from exchange import Message, Tool
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from exchange.providers.utils import get_env_url
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status, raise_for_status
from exchange.langfuse_wrapper import observe_wrapper

ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"

retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
Expand All @@ -24,6 +21,8 @@ class AnthropicProvider(Provider):
"""Provides chat completions for models hosted directly by Anthropic."""

PROVIDER_NAME = "anthropic"
BASE_URL_ENV_VAR = "ANTHROPIC_HOST"
BASE_URL_DEFAULT = "https://api.anthropic.com/v1/messages"
REQUIRED_ENV_VARS = ["ANTHROPIC_API_KEY"]

def __init__(self, client: httpx.Client) -> None:
Expand All @@ -32,7 +31,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["AnthropicProvider"]) -> "AnthropicProvider":
cls.check_env_vars()
url = get_env_url("ANTHROPIC_HOST", ANTHROPIC_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("ANTHROPIC_API_KEY")
client = httpx.Client(
base_url=url,
Expand Down Expand Up @@ -165,5 +164,5 @@ def recommended_models() -> tuple[str, str]:

@retry_procedure
def _post(self, payload: dict) -> httpx.Response:
response = self.client.post(ANTHROPIC_HOST, json=payload)
response = self.client.post("", json=payload)
return raise_for_status(response).json()
5 changes: 2 additions & 3 deletions packages/exchange/src/exchange/providers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import os

from exchange.providers import OpenAiProvider
from exchange.providers.utils import get_env_url


class AzureProvider(OpenAiProvider):
"""Provides chat completions for models hosted by the Azure OpenAI Service."""

PROVIDER_NAME = "azure"
BASE_URL_ENV_VAR = "AZURE_CHAT_COMPLETIONS_HOST_NAME"
REQUIRED_ENV_VARS = [
"AZURE_CHAT_COMPLETIONS_HOST_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION",
"AZURE_CHAT_COMPLETIONS_KEY",
Expand All @@ -22,7 +21,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["AzureProvider"]) -> "AzureProvider":
cls.check_env_vars()
url = get_env_url("AZURE_CHAT_COMPLETIONS_HOST_NAME")
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME")
api_version = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION")
key = os.environ.get("AZURE_CHAT_COMPLETIONS_KEY")
Expand Down
17 changes: 16 additions & 1 deletion packages/exchange/src/exchange/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import httpx
import os
from abc import ABC, abstractmethod
from attrs import define, field
Expand All @@ -22,6 +23,8 @@ def __init__(self, provider_cls: str) -> None:

class Provider(ABC):
PROVIDER_NAME: str
BASE_URL_ENV_VAR: str = ""
BASE_URL_DEFAULT: str = ""
REQUIRED_ENV_VARS: list[str] = []

@classmethod
Expand All @@ -32,11 +35,23 @@ def from_env(cls: type["Provider"]) -> "Provider":

@classmethod
def check_env_vars(cls: type["Provider"], instructions_url: Optional[str] = None) -> None:
provider = cls.PROVIDER_NAME
missing_vars = [x for x in cls.REQUIRED_ENV_VARS if x not in os.environ]

url_var = cls.BASE_URL_ENV_VAR
if url_var:
val = os.environ.get(url_var, cls.BASE_URL_DEFAULT)
if not val:
raise KeyError(url_var)
else:
url = httpx.URL(val)

if url.scheme not in ["http", "https"]:
raise ValueError(f"Expected {url_var} to be a 'http' or 'https' url: {val}")

if missing_vars:
env_vars = ", ".join(missing_vars)
raise MissingProviderEnvVariableError(env_vars, cls.PROVIDER_NAME, instructions_url)
raise MissingProviderEnvVariableError(env_vars, provider, instructions_url)

@abstractmethod
def complete(
Expand Down
10 changes: 4 additions & 6 deletions packages/exchange/src/exchange/providers/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status, get_env_url
from exchange.providers.utils import raise_for_status, retry_if_status
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
Expand All @@ -31,10 +31,8 @@ class DatabricksProvider(Provider):
"""

PROVIDER_NAME = "databricks"
REQUIRED_ENV_VARS = [
"DATABRICKS_HOST",
"DATABRICKS_TOKEN",
]
BASE_URL_ENV_VAR = "DATABRICKS_HOST"
REQUIRED_ENV_VARS = ["DATABRICKS_TOKEN"]
instructions_url = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"

def __init__(self, client: httpx.Client) -> None:
Expand All @@ -43,7 +41,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["DatabricksProvider"]) -> "DatabricksProvider":
cls.check_env_vars(cls.instructions_url)
url = get_env_url("DATABRICKS_HOST")
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR))
key = os.environ.get("DATABRICKS_TOKEN")
client = httpx.Client(
base_url=url,
Expand Down
9 changes: 4 additions & 5 deletions packages/exchange/src/exchange/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status, get_env_url, encode_image
from exchange.providers.utils import raise_for_status, retry_if_status, encode_image
from exchange.langfuse_wrapper import observe_wrapper


GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"

retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
Expand All @@ -24,6 +21,8 @@ class GoogleProvider(Provider):
"""Provides chat completions for models hosted by Google, including Gemini and other experimental models."""

PROVIDER_NAME = "google"
BASE_URL_ENV_VAR = "GOOGLE_HOST"
BASE_URL_DEFAULT = "https://generativelanguage.googleapis.com/v1beta"
REQUIRED_ENV_VARS = ["GOOGLE_API_KEY"]
instructions_url = "https://ai.google.dev/gemini-api/docs/api-key"

Expand All @@ -33,7 +32,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["GoogleProvider"]) -> "GoogleProvider":
cls.check_env_vars(cls.instructions_url)
url = get_env_url("GOOGLE_HOST", GOOGLE_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("GOOGLE_API_KEY")
client = httpx.Client(
base_url=url,
Expand Down
7 changes: 3 additions & 4 deletions packages/exchange/src/exchange/providers/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@
openai_single_message_context_length_exceeded,
raise_for_status,
tools_to_openai_spec,
get_env_url,
)
from exchange.tool import Tool
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status

GROQ_HOST = "https://api.groq.com/openai/"

retry_procedure = retry(
wait=wait_fixed(5),
stop=stop_after_attempt(5),
Expand All @@ -31,6 +28,8 @@ class GroqProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI."""

PROVIDER_NAME = "groq"
BASE_URL_ENV_VAR = "GROQ_HOST"
BASE_URL_DEFAULT = "https://api.groq.com/openai/"
REQUIRED_ENV_VARS = ["GROQ_API_KEY"]
instructions_url = "https://console.groq.com/docs/quickstart"

Expand All @@ -40,7 +39,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["GroqProvider"]) -> "GroqProvider":
cls.check_env_vars(cls.instructions_url)
url = get_env_url("GROQ_HOST", GROQ_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("GROQ_API_KEY")

client = httpx.Client(
Expand Down
6 changes: 3 additions & 3 deletions packages/exchange/src/exchange/providers/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

from typing import Type
from exchange.providers.openai import OpenAiProvider
from exchange.providers.utils import get_env_url

OLLAMA_HOST = "http://localhost:11434/"
OLLAMA_MODEL = "qwen2.5"


Expand All @@ -26,14 +24,16 @@ class OllamaProvider(OpenAiProvider):
requires: {{}}
"""
PROVIDER_NAME = "ollama"
BASE_URL_ENV_VAR = "OLLAMA_HOST"
BASE_URL_DEFAULT = "http://localhost:11434/"

def __init__(self, client: httpx.Client) -> None:
print("PLEASE NOTE: the ollama provider is experimental, use with care")
super().__init__(client)

@classmethod
def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider":
ollama_url = get_env_url("OLLAMA_HOST", OLLAMA_HOST)
ollama_url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
timeout = httpx.Timeout(60 * 10)

# from_env is expected to fail if required ENV variables are not
Expand Down
6 changes: 3 additions & 3 deletions packages/exchange/src/exchange/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
openai_single_message_context_length_exceeded,
raise_for_status,
tools_to_openai_spec,
get_env_url,
)
from exchange.tool import Tool
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.langfuse_wrapper import observe_wrapper

OPENAI_HOST = "https://api.openai.com/"

retry_procedure = retry(
wait=wait_fixed(2),
Expand All @@ -31,6 +29,8 @@ class OpenAiProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI."""

PROVIDER_NAME = "openai"
BASE_URL_ENV_VAR = "OPENAI_HOST"
BASE_URL_DEFAULT = "https://api.openai.com/"
REQUIRED_ENV_VARS = ["OPENAI_API_KEY"]
instructions_url = "https://platform.openai.com/docs/api-reference/api-keys"

Expand All @@ -40,7 +40,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["OpenAiProvider"]) -> "OpenAiProvider":
cls.check_env_vars(cls.instructions_url)
url = get_env_url("OPENAI_HOST", OPENAI_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("OPENAI_API_KEY")

client = httpx.Client(
Expand Down
22 changes: 0 additions & 22 deletions packages/exchange/src/exchange/providers/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import json
import os
import re
from typing import Optional

Expand All @@ -11,27 +10,6 @@
from tenacity import retry_if_exception


def get_env_url(key: str, default: str = "") -> httpx.URL:
"""
Returns a valid 'http' or 'https' URL.
:param key: The environment key
:param default: The URL default value
:raises ValueError: If the URL scheme is not 'http' or 'https'
"""

val = os.environ.get(key, default)
if val == "":
raise ValueError(f"{key} was empty")

url = httpx.URL(val)

if url.scheme not in ["http", "https"]:
raise ValueError(f"expected {key} to be a 'http' or 'https' url: {val}")

return url


def retry_if_status(codes: Optional[list[int]] = None, above: Optional[int] = None) -> callable:
codes = codes or []

Expand Down
101 changes: 101 additions & 0 deletions packages/exchange/tests/providers/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import pytest

from exchange.providers.base import MissingProviderEnvVariableError, Provider


def test_missing_provider_env_variable_error_without_instructions_url():
env_variable = "API_KEY"
provider = "TestProvider"
error = MissingProviderEnvVariableError(env_variable, provider)

assert error.env_variable == env_variable
assert error.provider == provider
assert error.instructions_url is None
assert error.message == "Missing environment variables: API_KEY for provider TestProvider."


def test_missing_provider_env_variable_error_with_instructions_url():
env_variable = "API_KEY"
provider = "TestProvider"
instructions_url = "http://example.com/instructions"
error = MissingProviderEnvVariableError(env_variable, provider, instructions_url)

assert error.env_variable == env_variable
assert error.provider == provider
assert error.instructions_url == instructions_url
assert error.message == (
"Missing environment variables: API_KEY for provider TestProvider.\n"
"Please see http://example.com/instructions for instructions"
)


class TestProvider(Provider):
PROVIDER_NAME = "test_provider"
REQUIRED_ENV_VARS = []

def complete(self, model, system, messages, tools, **kwargs):
pass


class TestProviderBaseURL(Provider):
PROVIDER_NAME = "test_provider_base_url"
BASE_URL_ENV_VAR = "TEST_PROVIDER_BASE_URL"
REQUIRED_ENV_VARS = []

def complete(self, model, system, messages, tools, **kwargs):
pass


class TestProviderBaseURLDefault(Provider):
PROVIDER_NAME = "test_provider_base_url_default"
BASE_URL_ENV_VAR = "TEST_PROVIDER_BASE_URL_DEFAULT"
BASE_URL_DEFAULT = "http://localhost:11434/"
REQUIRED_ENV_VARS = []

def complete(self, model, system, messages, tools, **kwargs):
pass


def test_check_env_vars_no_base_url():
TestProvider.check_env_vars()


def test_check_env_vars_base_url_valid_http(monkeypatch):
monkeypatch.setenv(TestProviderBaseURL.BASE_URL_ENV_VAR, "http://localhost:11434/")

TestProviderBaseURL.check_env_vars()


def test_check_env_vars_base_url_valid_https(monkeypatch):
monkeypatch.setenv(TestProviderBaseURL.BASE_URL_ENV_VAR, "https://localhost:11434/v1")

TestProviderBaseURL.check_env_vars()


def test_check_env_vars_base_url_default():
TestProviderBaseURLDefault.check_env_vars()


def test_check_env_vars_base_url_throw_error_when_empty(monkeypatch):
monkeypatch.setenv(TestProviderBaseURL.BASE_URL_ENV_VAR, "")

with pytest.raises(KeyError, match="TEST_PROVIDER_BASE_URL"):
TestProviderBaseURL.check_env_vars()


def test_check_env_vars_base_url_throw_error_when_missing_schemes(monkeypatch):
monkeypatch.setenv(TestProviderBaseURL.BASE_URL_ENV_VAR, "localhost:11434")

with pytest.raises(
ValueError, match="Expected TEST_PROVIDER_BASE_URL to be a 'http' or 'https' url: localhost:11434"
):
TestProviderBaseURL.check_env_vars()


def test_check_env_vars_base_url_throw_error_when_invalid_scheme(monkeypatch):
monkeypatch.setenv(TestProviderBaseURL.BASE_URL_ENV_VAR, "ftp://localhost:11434/v1")

with pytest.raises(
ValueError, match="Expected TEST_PROVIDER_BASE_URL to be a 'http' or 'https' url: ftp://localhost:11434/v1"
):
TestProviderBaseURL.check_env_vars()
Loading

0 comments on commit 224d729

Please sign in to comment.