Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: validate provider urls before use #147

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions packages/exchange/src/exchange/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from exchange.providers.utils import retry_if_status, raise_for_status
from exchange.langfuse_wrapper import observe_wrapper

ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"

retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
Expand All @@ -23,6 +21,8 @@ class AnthropicProvider(Provider):
"""Provides chat completions for models hosted directly by Anthropic."""

PROVIDER_NAME = "anthropic"
BASE_URL_ENV_VAR = "ANTHROPIC_HOST"
BASE_URL_DEFAULT = "https://api.anthropic.com/v1/messages"
REQUIRED_ENV_VARS = ["ANTHROPIC_API_KEY"]

def __init__(self, client: httpx.Client) -> None:
Expand All @@ -31,7 +31,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["AnthropicProvider"]) -> "AnthropicProvider":
cls.check_env_vars()
url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("ANTHROPIC_API_KEY")
client = httpx.Client(
base_url=url,
Expand Down Expand Up @@ -164,5 +164,5 @@ def recommended_models() -> tuple[str, str]:

@retry_procedure
def _post(self, payload: dict) -> httpx.Response:
response = self.client.post(ANTHROPIC_HOST, json=payload)
response = self.client.post(self.BASE_URL_DEFAULT, json=payload)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is incorrect as it defers to the default value without considering an override. OTOH, if I make this empty, it fails tests...

return raise_for_status(response).json()
6 changes: 3 additions & 3 deletions packages/exchange/src/exchange/providers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ class AzureProvider(OpenAiProvider):
"""Provides chat completions for models hosted by the Azure OpenAI Service."""

PROVIDER_NAME = "azure"
BASE_URL_ENV_VAR = "AZURE_CHAT_COMPLETIONS_HOST_NAME"
REQUIRED_ENV_VARS = [
"AZURE_CHAT_COMPLETIONS_HOST_NAME",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I separated out the base URL enforcement from the other ENV vars as there is special handling

"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION",
"AZURE_CHAT_COMPLETIONS_KEY",
Expand All @@ -21,13 +21,13 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["AzureProvider"]) -> "AzureProvider":
cls.check_env_vars()
url = os.environ.get("AZURE_CHAT_COMPLETIONS_HOST_NAME")
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME")
api_version = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION")
key = os.environ.get("AZURE_CHAT_COMPLETIONS_KEY")

# format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}/"
url = url.join(f"/openai/deployments/{deployment_name}/")
client = httpx.Client(
base_url=url,
headers={"api-key": key, "Content-Type": "application/json"},
Expand Down
17 changes: 16 additions & 1 deletion packages/exchange/src/exchange/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import httpx
import os
from abc import ABC, abstractmethod
from attrs import define, field
Expand All @@ -22,6 +23,8 @@ def __init__(self, provider_cls: str) -> None:

class Provider(ABC):
PROVIDER_NAME: str
BASE_URL_ENV_VAR: str = ""
BASE_URL_DEFAULT: str = ""
REQUIRED_ENV_VARS: list[str] = []

@classmethod
Expand All @@ -32,11 +35,23 @@ def from_env(cls: type["Provider"]) -> "Provider":

@classmethod
def check_env_vars(cls: type["Provider"], instructions_url: Optional[str] = None) -> None:
provider = cls.PROVIDER_NAME
missing_vars = [x for x in cls.REQUIRED_ENV_VARS if x not in os.environ]

url_var = cls.BASE_URL_ENV_VAR
if url_var:
val = os.environ.get(url_var, cls.BASE_URL_DEFAULT)
if not val:
raise KeyError(url_var)
else:
url = httpx.URL(val)

if url.scheme not in ["http", "https"]:
raise ValueError(f"Expected {url_var} to be a 'http' or 'https' url: {val}")

if missing_vars:
env_vars = ", ".join(missing_vars)
raise MissingProviderEnvVariableError(env_vars, cls.PROVIDER_NAME, instructions_url)
raise MissingProviderEnvVariableError(env_vars, provider, instructions_url)

@abstractmethod
def complete(
Expand Down
8 changes: 3 additions & 5 deletions packages/exchange/src/exchange/providers/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ class DatabricksProvider(Provider):
"""

PROVIDER_NAME = "databricks"
REQUIRED_ENV_VARS = [
"DATABRICKS_HOST",
"DATABRICKS_TOKEN",
]
BASE_URL_ENV_VAR = "DATABRICKS_HOST"
REQUIRED_ENV_VARS = ["DATABRICKS_TOKEN"]
instructions_url = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"

def __init__(self, client: httpx.Client) -> None:
Expand All @@ -43,7 +41,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["DatabricksProvider"]) -> "DatabricksProvider":
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("DATABRICKS_HOST")
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR))
key = os.environ.get("DATABRICKS_TOKEN")
client = httpx.Client(
base_url=url,
Expand Down
7 changes: 3 additions & 4 deletions packages/exchange/src/exchange/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
from exchange.providers.utils import raise_for_status, retry_if_status, encode_image
from exchange.langfuse_wrapper import observe_wrapper


GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"

retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
Expand All @@ -24,6 +21,8 @@ class GoogleProvider(Provider):
"""Provides chat completions for models hosted by Google, including Gemini and other experimental models."""

PROVIDER_NAME = "google"
BASE_URL_ENV_VAR = "GOOGLE_HOST"
BASE_URL_DEFAULT = "https://generativelanguage.googleapis.com/v1beta"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it make sense to have the constant at the top of the file just as people may want to quickly scan/change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, these constants are accessed as class variables, so if we move to the top they'd no longer be that. I'm not sure a hack to work around this..

REQUIRED_ENV_VARS = ["GOOGLE_API_KEY"]
instructions_url = "https://ai.google.dev/gemini-api/docs/api-key"

Expand All @@ -33,7 +32,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["GoogleProvider"]) -> "GoogleProvider":
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("GOOGLE_API_KEY")
client = httpx.Client(
base_url=url,
Expand Down
6 changes: 3 additions & 3 deletions packages/exchange/src/exchange/providers/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status

GROQ_HOST = "https://api.groq.com/openai/"

retry_procedure = retry(
wait=wait_fixed(5),
stop=stop_after_attempt(5),
Expand All @@ -30,6 +28,8 @@ class GroqProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI."""

PROVIDER_NAME = "groq"
BASE_URL_ENV_VAR = "GROQ_HOST"
BASE_URL_DEFAULT = "https://api.groq.com/openai/"
REQUIRED_ENV_VARS = ["GROQ_API_KEY"]
instructions_url = "https://console.groq.com/docs/quickstart"

Expand All @@ -39,7 +39,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["GroqProvider"]) -> "GroqProvider":
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("GROQ_HOST", GROQ_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("GROQ_API_KEY")

client = httpx.Client(
Expand Down
13 changes: 8 additions & 5 deletions packages/exchange/src/exchange/providers/ollama.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os

import httpx

from typing import Type
from exchange.providers.openai import OpenAiProvider

OLLAMA_HOST = "http://localhost:11434/"
OLLAMA_MODEL = "qwen2.5"


Expand All @@ -25,14 +24,18 @@ class OllamaProvider(OpenAiProvider):
requires: {{}}
"""
PROVIDER_NAME = "ollama"
BASE_URL_ENV_VAR = "OLLAMA_HOST"
BASE_URL_DEFAULT = "http://localhost:11434/"
REQUIRED_ENV_VARS = []

def __init__(self, client: httpx.Client) -> None:
print("PLEASE NOTE: the ollama provider is experimental, use with care")
super().__init__(client)

@classmethod
def from_env(cls: type["OllamaProvider"]) -> "OllamaProvider":
ollama_url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST)
def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider":
cls.check_env_vars(cls.instructions_url)
ollama_url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
timeout = httpx.Timeout(60 * 10)

# from_env is expected to fail if required ENV variables are not
Expand All @@ -41,7 +44,7 @@ def from_env(cls: type["OllamaProvider"]) -> "OllamaProvider":
httpx.get(ollama_url, timeout=timeout)

# When served by Ollama, the OpenAI API is available at the path "v1/".
client = httpx.Client(base_url=ollama_url + "v1/", timeout=timeout)
client = httpx.Client(base_url=ollama_url.join("v1/"), timeout=timeout)
return cls(client)

@staticmethod
Expand Down
7 changes: 4 additions & 3 deletions packages/exchange/src/exchange/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from exchange.providers.utils import retry_if_status
from exchange.langfuse_wrapper import observe_wrapper

OPENAI_HOST = "https://api.openai.com/"

retry_procedure = retry(
wait=wait_fixed(2),
Expand All @@ -30,6 +29,8 @@ class OpenAiProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI."""

PROVIDER_NAME = "openai"
BASE_URL_ENV_VAR = "OPENAI_HOST"
BASE_URL_DEFAULT = "https://api.openai.com/"
REQUIRED_ENV_VARS = ["OPENAI_API_KEY"]
instructions_url = "https://platform.openai.com/docs/api-reference/api-keys"

Expand All @@ -39,11 +40,11 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["OpenAiProvider"]) -> "OpenAiProvider":
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("OPENAI_HOST", OPENAI_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("OPENAI_API_KEY")

client = httpx.Client(
base_url=url + "v1/",
base_url=url.join("v1/"),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fyi using httpx.URL as it isn't so sensitive about trailing slash etc when joining

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting, i always remembered // would resolve? just curious what bug you can into since i haven't seen this before!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it was more about not having a slash at all. e.g if you set a base URL ending in the port

auth=("Bearer", key),
timeout=httpx.Timeout(60 * 10),
)
Expand Down
8 changes: 8 additions & 0 deletions packages/exchange/tests/providers/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ def anthropic_provider():
return AnthropicProvider.from_env()


def test_from_env_throw_error_when_invalid_host(monkeypatch):
monkeypatch.setenv("ANTHROPIC_HOST", "localhost:1234")
monkeypatch.setenv("ANTHROPIC_API_KEY", "test_api_key")

with pytest.raises(ValueError, match="Expected ANTHROPIC_HOST to be a 'http' or 'https' url: localhost:1234"):
AnthropicProvider.from_env()


def test_from_env_throw_error_when_missing_api_key():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(MissingProviderEnvVariableError) as context:
Expand Down
14 changes: 12 additions & 2 deletions packages/exchange/tests/providers/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,21 @@
AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")


def test_from_env_throw_error_when_invalid_host(monkeypatch):
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_HOST_NAME", "localhost:1234")
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", "test_deployment_name")
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", "test_api_version")
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", "test_api_key")

with pytest.raises(
ValueError, match="Expected AZURE_CHAT_COMPLETIONS_HOST_NAME to be a 'http' or 'https' url: localhost:1234"
):
AzureProvider.from_env()


@pytest.mark.parametrize(
"env_var_name",
[
"AZURE_CHAT_COMPLETIONS_HOST_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION",
"AZURE_CHAT_COMPLETIONS_KEY",
Expand All @@ -24,7 +35,6 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name):
with patch.dict(
os.environ,
{
"AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version",
"AZURE_CHAT_COMPLETIONS_KEY": "test_api_key",
Expand Down
101 changes: 101 additions & 0 deletions packages/exchange/tests/providers/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import pytest

from exchange.providers.base import MissingProviderEnvVariableError, Provider


def test_missing_provider_env_variable_error_without_instructions_url():
env_variable = "API_KEY"
provider = "TestProvider"
error = MissingProviderEnvVariableError(env_variable, provider)

assert error.env_variable == env_variable
assert error.provider == provider
assert error.instructions_url is None
assert error.message == "Missing environment variables: API_KEY for provider TestProvider."


def test_missing_provider_env_variable_error_with_instructions_url():
env_variable = "API_KEY"
provider = "TestProvider"
instructions_url = "http://example.com/instructions"
error = MissingProviderEnvVariableError(env_variable, provider, instructions_url)

assert error.env_variable == env_variable
assert error.provider == provider
assert error.instructions_url == instructions_url
assert error.message == (
"Missing environment variables: API_KEY for provider TestProvider.\n"
"Please see http://example.com/instructions for instructions"
)


class TestProvider(Provider):
PROVIDER_NAME = "test_provider"
REQUIRED_ENV_VARS = []

def complete(self, model, system, messages, tools, **kwargs):
pass


class TestProviderBaseURL(Provider):
PROVIDER_NAME = "test_provider_base_url"
BASE_URL_ENV_VAR = "TEST_PROVIDER_BASE_URL"
REQUIRED_ENV_VARS = []

def complete(self, model, system, messages, tools, **kwargs):
pass


class TestProviderBaseURLDefault(Provider):
PROVIDER_NAME = "test_provider_base_url_default"
BASE_URL_ENV_VAR = "TEST_PROVIDER_BASE_URL_DEFAULT"
BASE_URL_DEFAULT = "http://localhost:11434/"
REQUIRED_ENV_VARS = []

def complete(self, model, system, messages, tools, **kwargs):
pass


def test_check_env_vars_no_base_url():
TestProvider.check_env_vars()


def test_check_env_vars_base_url_valid_http(monkeypatch):
monkeypatch.setenv(TestProviderBaseURL.BASE_URL_ENV_VAR, "http://localhost:11434/")

TestProviderBaseURL.check_env_vars()


def test_check_env_vars_base_url_valid_https(monkeypatch):
monkeypatch.setenv(TestProviderBaseURL.BASE_URL_ENV_VAR, "https://localhost:11434/v1")

TestProviderBaseURL.check_env_vars()


def test_check_env_vars_base_url_default():
TestProviderBaseURLDefault.check_env_vars()


def test_check_env_vars_base_url_throw_error_when_empty(monkeypatch):
monkeypatch.setenv(TestProviderBaseURL.BASE_URL_ENV_VAR, "")

with pytest.raises(KeyError, match="TEST_PROVIDER_BASE_URL"):
TestProviderBaseURL.check_env_vars()


def test_check_env_vars_base_url_throw_error_when_missing_schemes(monkeypatch):
monkeypatch.setenv(TestProviderBaseURL.BASE_URL_ENV_VAR, "localhost:11434")

with pytest.raises(
ValueError, match="Expected TEST_PROVIDER_BASE_URL to be a 'http' or 'https' url: localhost:11434"
):
TestProviderBaseURL.check_env_vars()


def test_check_env_vars_base_url_throw_error_when_invalid_scheme(monkeypatch):
monkeypatch.setenv(TestProviderBaseURL.BASE_URL_ENV_VAR, "ftp://localhost:11434/v1")

with pytest.raises(
ValueError, match="Expected TEST_PROVIDER_BASE_URL to be a 'http' or 'https' url: ftp://localhost:11434/v1"
):
TestProviderBaseURL.check_env_vars()
Loading