From 070f894aa965be5aff8438a6d203fe5d7bf28151 Mon Sep 17 00:00:00 2001
From: Senko Rasic
Date: Wed, 5 Jun 2024 09:53:46 +0200
Subject: [PATCH] add first-class support for Azure OpenAI

---
 core/config/__init__.py   | 12 +++++++++++-
 core/llm/azure_client.py  | 29 +++++++++++++++++++++++++++++
 core/llm/base.py          |  3 +++
 core/llm/openai_client.py |  7 ++++---
 example-config.json       | 13 ++++++++++++-
 5 files changed, 59 insertions(+), 5 deletions(-)
 create mode 100644 core/llm/azure_client.py

diff --git a/core/config/__init__.py b/core/config/__init__.py
index e9bd8ccef..ba826b1ee 100644
--- a/core/config/__init__.py
+++ b/core/config/__init__.py
@@ -1,6 +1,6 @@
 from enum import Enum
 from os.path import abspath, dirname, isdir, join
-from typing import Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 from typing_extensions import Annotated
@@ -55,6 +55,7 @@ class LLMProvider(str, Enum):
     ANTHROPIC = "anthropic"
     GROQ = "groq"
     LM_STUDIO = "lm-studio"
+    AZURE = "azure"
 
 
 class UIAdapter(str, Enum):
@@ -89,6 +90,10 @@
         description="Timeout (in seconds) for receiving a new chunk of data from the response stream",
         ge=0.0,
     )
+    extra: Optional[dict[str, Any]] = Field(
+        None,
+        description="Extra provider-specific configuration",
+    )
 
 
 class AgentLLMConfig(_StrictModel):
@@ -140,6 +145,10 @@
         description="Timeout (in seconds) for receiving a new chunk of data from the response stream",
         ge=0.0,
     )
+    extra: Optional[dict[str, Any]] = Field(
+        None,
+        description="Extra provider-specific configuration",
+    )
 
     @classmethod
     def from_provider_and_agent_configs(cls, provider: ProviderConfig, agent: AgentLLMConfig):
@@ -151,6 +160,7 @@ def from_provider_and_agent_configs(cls, provider: ProviderConfig, agent: AgentL
             temperature=agent.temperature,
             connect_timeout=provider.connect_timeout,
             read_timeout=provider.read_timeout,
+            extra=provider.extra,
         )
 
 
diff --git a/core/llm/azure_client.py b/core/llm/azure_client.py
new file mode 100644
index 000000000..300bbe450
--- /dev/null
+++ b/core/llm/azure_client.py
@@ -0,0 +1,29 @@
+from httpx import Timeout
+from openai import AsyncAzureOpenAI
+
+from core.config import LLMProvider
+from core.llm.openai_client import OpenAIClient
+from core.log import get_logger
+
+log = get_logger(__name__)
+
+
+class AzureClient(OpenAIClient):
+    provider = LLMProvider.AZURE
+    stream_options = None
+
+    def _init_client(self):
+        azure_deployment = (self.config.extra or {}).get("azure_deployment")
+        api_version = (self.config.extra or {}).get("api_version")
+
+        self.client = AsyncAzureOpenAI(
+            api_key=self.config.api_key,
+            azure_endpoint=self.config.base_url,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            timeout=Timeout(
+                max(self.config.connect_timeout, self.config.read_timeout),
+                connect=self.config.connect_timeout,
+                read=self.config.read_timeout,
+            ),
+        )
diff --git a/core/llm/base.py b/core/llm/base.py
index 7881dae40..afe7f3219 100644
--- a/core/llm/base.py
+++ b/core/llm/base.py
@@ -316,6 +316,7 @@ def for_provider(provider: LLMProvider) -> type["BaseLLMClient"]:
         :return: Client class for the specified provider.
""" from .anthropic_client import AnthropicClient + from .azure_client import AzureClient from .groq_client import GroqClient from .openai_client import OpenAIClient @@ -325,6 +326,8 @@ def for_provider(provider: LLMProvider) -> type["BaseLLMClient"]: return AnthropicClient elif provider == LLMProvider.GROQ: return GroqClient + elif provider == LLMProvider.AZURE: + return AzureClient else: raise ValueError(f"Unsupported LLM provider: {provider.value}") diff --git a/core/llm/openai_client.py b/core/llm/openai_client.py index feddaf581..1933ab54a 100644 --- a/core/llm/openai_client.py +++ b/core/llm/openai_client.py @@ -17,6 +17,7 @@ class OpenAIClient(BaseLLMClient): provider = LLMProvider.OPENAI + stream_options = {"include_usage": True} def _init_client(self): self.client = AsyncOpenAI( @@ -40,10 +41,10 @@ async def _make_request( "messages": convo.messages, "temperature": self.config.temperature if temperature is None else temperature, "stream": True, - "stream_options": { - "include_usage": True, - }, } + if self.stream_options: + completion_kwargs["stream_options"] = self.stream_options + if json_mode: completion_kwargs["response_format"] = {"type": "json_object"} diff --git a/example-config.json b/example-config.json index b923881bf..e34fb4c3d 100644 --- a/example-config.json +++ b/example-config.json @@ -1,6 +1,6 @@ { // Configuration for the LLM providers that can be used. Pythagora supports - // OpenAI, Anthropic and Groq. Azure and OpenRouter and local LLMs (such as LM-Studio) + // OpenAI, Azure, Anthropic and Groq. OpenRouter and local LLMs (such as LM-Studio) // also work, you can use "openai" provider to define these. "llm": { "openai": { @@ -9,6 +9,17 @@ "api_key": null, "connect_timeout": 60.0, "read_timeout": 10.0 + }, + // Example config for Azure OpenAI (see https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions) + "azure": { + "base_url": "https://your-resource-name.openai.azure.com/", + "api_key": "your-api-key", + "connect_timeout": 60.0, + "read_timeout": 10.0, + "extra": { + "azure_deployment": "your-azure-deployment-id", + "api_version": "2024-02-01" + } } }, // Each agent can use a different model or configuration. The default, as before, is GPT4 Turbo