add first-class support for Azure OpenAI
senko committed Jun 5, 2024
1 parent f910aa2 commit 070f894
Showing 5 changed files with 59 additions and 5 deletions.
12 changes: 11 additions & 1 deletion core/config/__init__.py
@@ -1,6 +1,6 @@
 from enum import Enum
 from os.path import abspath, dirname, isdir, join
-from typing import Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 from typing_extensions import Annotated
@@ -55,6 +55,7 @@ class LLMProvider(str, Enum):
     ANTHROPIC = "anthropic"
     GROQ = "groq"
     LM_STUDIO = "lm-studio"
+    AZURE = "azure"
 
 
 class UIAdapter(str, Enum):
@@ -89,6 +90,10 @@ class ProviderConfig(_StrictModel):
         description="Timeout (in seconds) for receiving a new chunk of data from the response stream",
         ge=0.0,
     )
+    extra: Optional[dict[str, Any]] = Field(
+        None,
+        description="Extra provider-specific configuration",
+    )
 
 
 class AgentLLMConfig(_StrictModel):
@@ -140,6 +145,10 @@ class LLMConfig(_StrictModel):
         description="Timeout (in seconds) for receiving a new chunk of data from the response stream",
         ge=0.0,
     )
+    extra: Optional[dict[str, Any]] = Field(
+        None,
+        description="Extra provider-specific configuration",
+    )
 
     @classmethod
     def from_provider_and_agent_configs(cls, provider: ProviderConfig, agent: AgentLLMConfig):
@@ -151,6 +160,7 @@ def from_provider_and_agent_configs(cls, provider: ProviderConfig, agent: AgentLLMConfig):
             temperature=agent.temperature,
             connect_timeout=provider.connect_timeout,
             read_timeout=provider.read_timeout,
+            extra=provider.extra,
         )


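Taken together, the config changes let any provider entry carry an opaque `extra` dict that survives the provider/agent merge. A minimal sketch of the flow, assuming `AgentLLMConfig` exposes `provider` and `model` fields (only `temperature` is visible in this diff):

```python
from core.config import AgentLLMConfig, LLMConfig, LLMProvider, ProviderConfig

# Provider-level entry, carrying Azure specifics in the new "extra" dict.
provider = ProviderConfig(
    base_url="https://your-resource-name.openai.azure.com/",
    api_key="your-api-key",
    extra={"azure_deployment": "your-azure-deployment-id", "api_version": "2024-02-01"},
)

# Agent-level entry; "provider" and "model" are assumed field names here,
# since only "temperature" appears in the diff above.
agent = AgentLLMConfig(provider=LLMProvider.AZURE, model="gpt-4", temperature=0.5)

# The merged per-agent config forwards the provider's "extra" dict verbatim.
config = LLMConfig.from_provider_and_agent_configs(provider, agent)
assert config.extra == provider.extra
```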
29 changes: 29 additions & 0 deletions core/llm/azure_client.py
@@ -0,0 +1,29 @@
+from httpx import Timeout
+from openai import AsyncAzureOpenAI
+
+from core.config import LLMProvider
+from core.llm.openai_client import OpenAIClient
+from core.log import get_logger
+
+log = get_logger(__name__)
+
+
+class AzureClient(OpenAIClient):
+    provider = LLMProvider.AZURE
+    stream_options = None
+
+    def _init_client(self):
+        azure_deployment = self.config.extra.get("azure_deployment")
+        api_version = self.config.extra.get("api_version")
+
+        self.client = AsyncAzureOpenAI(
+            api_key=self.config.api_key,
+            azure_endpoint=self.config.base_url,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            timeout=Timeout(
+                max(self.config.connect_timeout, self.config.read_timeout),
+                connect=self.config.connect_timeout,
+                read=self.config.read_timeout,
+            ),
+        )
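`AzureClient` inherits the whole request and streaming path from `OpenAIClient` and only swaps out client construction; `stream_options = None` opts Azure out of the `include_usage` flag that the vanilla OpenAI endpoint accepts. Note that `_init_client()` calls `.get()` on `self.config.extra`, so an Azure config without an `extra` dict would fail with `AttributeError`. A usage sketch under stated assumptions: the `LLMConfig` constructor fields and the claim that `BaseLLMClient.__init__` stores the config and invokes `_init_client()` are inferred from the override pattern, not shown in this diff.

```python
from core.config import LLMConfig, LLMProvider
from core.llm.azure_client import AzureClient

# Illustrative only: field names mirror the attributes _init_client() reads.
config = LLMConfig(
    provider=LLMProvider.AZURE,
    model="gpt-4",
    base_url="https://your-resource-name.openai.azure.com/",
    api_key="your-api-key",
    extra={"azure_deployment": "your-azure-deployment-id", "api_version": "2024-02-01"},
)

client = AzureClient(config)  # assumed to call _init_client(), building AsyncAzureOpenAI
```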
3 changes: 3 additions & 0 deletions core/llm/base.py
@@ -316,6 +316,7 @@ def for_provider(provider: LLMProvider) -> type["BaseLLMClient"]:
         :return: Client class for the specified provider.
         """
         from .anthropic_client import AnthropicClient
+        from .azure_client import AzureClient
         from .groq_client import GroqClient
         from .openai_client import OpenAIClient
 
@@ -325,6 +326,8 @@ def for_provider(provider: LLMProvider) -> type["BaseLLMClient"]:
             return AnthropicClient
         elif provider == LLMProvider.GROQ:
             return GroqClient
+        elif provider == LLMProvider.AZURE:
+            return AzureClient
         else:
             raise ValueError(f"Unsupported LLM provider: {provider.value}")

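The imports inside `for_provider` stay local to avoid import cycles between `base.py` and the concrete clients. Dispatch is then a lookup by enum value; a short sketch, assuming `for_provider` is exposed as a static method on `BaseLLMClient`, as its signature (no `self`/`cls`) suggests:

```python
from core.config import LLMProvider
from core.llm.base import BaseLLMClient

client_class = BaseLLMClient.for_provider(LLMProvider.AZURE)
print(client_class.__name__)  # AzureClient
```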
7 changes: 4 additions & 3 deletions core/llm/openai_client.py
@@ -17,6 +17,7 @@
 
 class OpenAIClient(BaseLLMClient):
     provider = LLMProvider.OPENAI
+    stream_options = {"include_usage": True}
 
     def _init_client(self):
         self.client = AsyncOpenAI(
@@ -40,10 +41,10 @@ async def _make_request(
             "messages": convo.messages,
             "temperature": self.config.temperature if temperature is None else temperature,
             "stream": True,
-            "stream_options": {
-                "include_usage": True,
-            },
         }
+        if self.stream_options:
+            completion_kwargs["stream_options"] = self.stream_options
 
         if json_mode:
             completion_kwargs["response_format"] = {"type": "json_object"}

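Hoisting the hard-coded `stream_options` payload into a class attribute turns it into a per-provider switch: `OpenAIClient` keeps `{"include_usage": True}`, while `AzureClient` sets `None` so the key is omitted from the request entirely (presumably because the Azure endpoint did not accept the parameter at the time). A self-contained toy illustrating the switch (the real classes live in `core/llm/openai_client.py` and `core/llm/azure_client.py`):

```python
# Toy stand-ins for the real client classes; only the class-attribute
# switch from the diff above is modeled here.
class _OpenAILike:
    stream_options = {"include_usage": True}


class _AzureLike(_OpenAILike):
    stream_options = None


def _completion_kwargs(client) -> dict:
    kwargs = {"stream": True}
    if client.stream_options:  # a falsy None drops the key entirely
        kwargs["stream_options"] = client.stream_options
    return kwargs


assert "stream_options" in _completion_kwargs(_OpenAILike())
assert "stream_options" not in _completion_kwargs(_AzureLike())
```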
13 changes: 12 additions & 1 deletion example-config.json
@@ -1,6 +1,6 @@
 {
     // Configuration for the LLM providers that can be used. Pythagora supports
-    // OpenAI, Anthropic and Groq. Azure and OpenRouter and local LLMs (such as LM-Studio)
+    // OpenAI, Azure, Anthropic and Groq. OpenRouter and local LLMs (such as LM-Studio)
     // also work, you can use "openai" provider to define these.
     "llm": {
         "openai": {
@@ -9,6 +9,17 @@
             "api_key": null,
             "connect_timeout": 60.0,
             "read_timeout": 10.0
+        },
+        // Example config for Azure OpenAI (see https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions)
+        "azure": {
+            "base_url": "https://your-resource-name.openai.azure.com/",
+            "api_key": "your-api-key",
+            "connect_timeout": 60.0,
+            "read_timeout": 10.0,
+            "extra": {
+                "azure_deployment": "your-azure-deployment-id",
+                "api_version": "2024-02-01"
+            }
         }
     },
     // Each agent can use a different model or configuration. The default, as before, is GPT4 Turbo
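Because `ProviderConfig` extends `_StrictModel`, Azure-only keys could not simply be added at the top level of a provider entry; nesting them under `"extra"` keeps the schema closed while staying provider-agnostic. A sketch of that trade-off, assuming `_StrictModel` forbids unknown fields (as its name and the `ConfigDict` import suggest) and that `ProviderConfig`'s other fields are optional (the example's `"api_key": null` suggests they are):

```python
from pydantic import ValidationError

from core.config import ProviderConfig

try:
    # Unknown top-level keys should be rejected by the strict schema.
    ProviderConfig(azure_deployment="your-azure-deployment-id")
except ValidationError as err:
    print("rejected:", err.error_count(), "error(s)")

# Arbitrary provider-specific keys are fine inside "extra".
azure = ProviderConfig(extra={"azure_deployment": "your-azure-deployment-id"})
```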
