Skip to content

Commit

Permalink
feat: add azure, examples, and message constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
yjp20 committed Aug 1, 2024
1 parent c329601 commit 80d17f9
Show file tree
Hide file tree
Showing 19 changed files with 1,070 additions and 135 deletions.
59 changes: 47 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,9 @@ func main() {
option.WithAPIKey("My API Key"), // defaults to os.LookupEnv("OPENAI_API_KEY")
)
chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{
Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser),
Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Say this is a test")),
}}),
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage("Say this is a test"),
}),
Model: openai.F(openai.ChatModelGPT4o),
})
if err != nil {
Expand Down Expand Up @@ -237,10 +236,9 @@ defer cancel()
client.Chat.Completions.New(
ctx,
openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{
Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser),
Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("How can I list all files in a directory using Python?")),
}}),
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage("Say this is a test"),
}),
Model: openai.F(openai.ChatModelGPT4o),
},
// This sets the per-retry timeout
Expand Down Expand Up @@ -300,10 +298,9 @@ client := openai.NewClient(
client.Chat.Completions.New(
context.TODO(),
openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{
Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser),
Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("How can I get the name of the current day in Node.js?")),
}}),
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage("Say this is a test"),
}),
Model: openai.F(openai.ChatModelGPT4o),
},
option.WithMaxRetries(5),
Expand Down Expand Up @@ -396,6 +393,44 @@ You may also replace the default `http.Client` with
accepted (this overwrites any previous client) and receives requests after any
middleware has been applied.

## Microsoft Azure OpenAI

To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the option.RequestOption functions in the `azure` package.

```go
package main

import (
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/openai/openai-go"
"github.com/openai/openai-go/azure"
"github.com/openai/openai-go/option"
)

func main() {
const azureOpenAIEndpoint = "https://<azure-openai-resource>.openai.azure.com"

// The latest API versions, including previews, can be found here:
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
const azureOpenAIAPIVersion = "2024-06-01"

tokenCredential, err := azidentity.NewDefaultAzureCredential(nil)

if err != nil {
fmt.Printf("Failed to create the DefaultAzureCredential: %s", err)
os.Exit(1)
}

client := openai.NewClient(
azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),

// Choose between authenticating using a TokenCredential or an API Key
azure.WithTokenCredential(tokenCredential),
// or azure.WithAPIKey(azureOpenAIAPIKey),
)
}
```

## Semantic versioning

This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions:
Expand Down
237 changes: 237 additions & 0 deletions azure/azure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
// Package azure provides configuration options so you can connect and use Azure OpenAI using the [openai.Client].
//
// Typical usage of this package will look like this:
//
// client := openai.NewClient(
// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
// azure.WithTokenCredential(azureIdentityTokenCredential),
// // or azure.WithAPIKey(azureOpenAIAPIKey),
// )
//
// Or, if you want to construct a specific service:
//
// client := openai.NewChatCompletionService(
// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
// azure.WithTokenCredential(azureIdentityTokenCredential),
// // or azure.WithAPIKey(azureOpenAIAPIKey),
// )
package azure

import (
"bytes"
"encoding/json"
"errors"
"io"
"mime"
"mime/multipart"
"net/http"
"net/url"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
)

// WithEndpoint configures this client to connect to an Azure OpenAI endpoint.
//
// - endpoint - the Azure OpenAI endpoint to connect to. Ex: https://<azure-openai-resource>.openai.azure.com
// - apiVersion - the Azure OpenAI API version to target (ex: 2024-06-01). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty.
//
// This function should be paired with a call to authenticate, like [azure.WithAPIKey] or [azure.WithTokenCredential], similar to this:
//
// client := openai.NewClient(
// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
// azure.WithTokenCredential(azureIdentityTokenCredential),
// // or azure.WithAPIKey(azureOpenAIAPIKey),
// )
//
// [Azure OpenAI apiversions]: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
func WithEndpoint(endpoint string, apiVersion string) option.RequestOption {
if !strings.HasSuffix(endpoint, "/") {
endpoint += "/"
}

endpoint += "openai/"

withQueryAdd := option.WithQueryAdd("api-version", apiVersion)
withEndpoint := option.WithBaseURL(endpoint)

withModelMiddleware := option.WithMiddleware(func(r *http.Request, mn option.MiddlewareNext) (*http.Response, error) {
replacementPath, err := getReplacementPathWithDeployment(r)

if err != nil {
return nil, err
}

r.URL.Path = replacementPath
return mn(r)
})

return func(rc *requestconfig.RequestConfig) error {
if apiVersion == "" {
return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details.")
}

if err := withQueryAdd(rc); err != nil {
return err
}

if err := withEndpoint(rc); err != nil {
return err
}

if err := withModelMiddleware(rc); err != nil {
return err
}

return nil
}
}

// WithTokenCredential configures this client to authenticate using an [Azure Identity] TokenCredential.
// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance.
//
// [Azure Identity]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity
func WithTokenCredential(tokenCredential azcore.TokenCredential) option.RequestOption {
bearerTokenPolicy := runtime.NewBearerTokenPolicy(tokenCredential, []string{"https://cognitiveservices.azure.com/.default"}, nil)

// add in a middleware that uses the bearer token generated from the token credential
return option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
pipeline := runtime.NewPipeline("azopenai-extensions", version, runtime.PipelineOptions{}, &policy.ClientOptions{
InsecureAllowCredentialWithHTTP: true, // allow for plain HTTP proxies, etc..
PerRetryPolicies: []policy.Policy{
bearerTokenPolicy,
policyAdapter(next),
},
})

req2, err := runtime.NewRequestFromRequest(req)

if err != nil {
return nil, err
}

return pipeline.Do(req2)
})
}

// WithAPIKey configures this client to authenticate using an API key.
// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance.
func WithAPIKey(apiKey string) option.RequestOption {
// NOTE: there is an option.WithApiKey(), but that adds the value into
// the Authorization header instead so we're doing this instead.
return option.WithHeader("Api-Key", apiKey)
}

// jsonRoutes have JSON payloads - we'll deserialize looking for a .model field in there
// so we won't have to worry about individual types for completions vs embeddings, etc...
var jsonRoutes = map[string]bool{
"/openai/completions": true,
"/openai/chat/completions": true,
"/openai/embeddings": true,
"/openai/audio/speech": true,
"/openai/images/generations": true,
}

// audioMultipartRoutes have mime/multipart payloads. These are less generic - we're very much
// expecting a transcription or translation payload for these.
var audioMultipartRoutes = map[string]bool{
"/openai/audio/transcriptions": true,
"/openai/audio/translations": true,
}

// getReplacementPathWithDeployment parses the request body to extract out the Model parameter (or equivalent)
// (note, the req.Body is fully read as part of this, and is replaced with a bytes.Reader)
func getReplacementPathWithDeployment(req *http.Request) (string, error) {
if jsonRoutes[req.URL.Path] {
return getJSONRoute(req)
}

if audioMultipartRoutes[req.URL.Path] {
return getAudioMultipartRoute(req)
}

// No need to relocate the path. We've already tacked on /openai when we setup the endpoint.
return req.URL.Path, nil
}

func getJSONRoute(req *http.Request) (string, error) {
// we need to deserialize the body, partly, in order to read out the model field.
jsonBytes, err := io.ReadAll(req.Body)

if err != nil {
return "", err
}

// make sure we restore the body so it can be used in later middlewares.
req.Body = io.NopCloser(bytes.NewReader(jsonBytes))

var v *struct {
Model string `json:"model"`
}

if err := json.Unmarshal(jsonBytes, &v); err != nil {
return "", err
}

escapedDeployment := url.PathEscape(v.Model)
return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil
}

func getAudioMultipartRoute(req *http.Request) (string, error) {
// body is a multipart/mime body type instead.
mimeBytes, err := io.ReadAll(req.Body)

if err != nil {
return "", err
}

// make sure we restore the body so it can be used in later middlewares.
req.Body = io.NopCloser(bytes.NewReader(mimeBytes))

_, mimeParams, err := mime.ParseMediaType(req.Header.Get("Content-Type"))

if err != nil {
return "", err
}

mimeReader := multipart.NewReader(
io.NopCloser(bytes.NewReader(mimeBytes)),
mimeParams["boundary"])

for {
mp, err := mimeReader.NextPart()

if err != nil {
if errors.Is(err, io.EOF) {
return "", errors.New("unable to find the model part in multipart body")
}

return "", err
}

defer mp.Close()

if mp.FormName() == "model" {
modelBytes, err := io.ReadAll(mp)

if err != nil {
return "", err
}

escapedDeployment := url.PathEscape(string(modelBytes))
return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil
}
}
}

type policyAdapter option.MiddlewareNext

func (mp policyAdapter) Do(req *policy.Request) (*http.Response, error) {
return (option.MiddlewareNext)(mp)(req.Raw())
}

const version = "v.0.1.0"
Loading

0 comments on commit 80d17f9

Please sign in to comment.