Skip to content

Commit

Permalink
Update all Azure components (#3475)
Browse files Browse the repository at this point in the history
Signed-off-by: Bernd Verst <github@bernd.dev>
  • Loading branch information
berndverst committed Jul 3, 2024
1 parent 273bea1 commit 2e35e1f
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 134 deletions.
78 changes: 50 additions & 28 deletions bindings/azure/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ package openai
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"

"github.com/dapr/components-contrib/bindings"
azauth "github.com/dapr/components-contrib/common/authentication/azure"
Expand Down Expand Up @@ -120,10 +121,10 @@ func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error {

if m.APIKey != "" {
// use API key authentication
var keyCredential azopenai.KeyCredential
keyCredential, err = azopenai.NewKeyCredential(m.APIKey)
if err != nil {
return fmt.Errorf("error getting credentials object: %w", err)
var keyCredential *azcore.KeyCredential
keyCredential = azcore.NewKeyCredential(m.APIKey)
if keyCredential == nil {
return errors.New("error getting credentials object")
}

p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, nil)
Expand Down Expand Up @@ -163,7 +164,7 @@ func (p *AzOpenAI) Operations() []bindings.OperationKind {
// Invoke handles all invoke operations.
func (p *AzOpenAI) Invoke(ctx context.Context, req *bindings.InvokeRequest) (resp *bindings.InvokeResponse, err error) {
if req == nil || len(req.Metadata) == 0 {
return nil, fmt.Errorf("invalid request: metadata is required")
return nil, errors.New("invalid request: metadata is required")
}

startTime := time.Now().UTC()
Expand Down Expand Up @@ -228,7 +229,7 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[
}

if prompt.Prompt == "" {
return nil, fmt.Errorf("prompt is required for completion operation")
return nil, errors.New("prompt is required for completion operation")
}

if prompt.DeploymentID == "" {
Expand All @@ -240,13 +241,13 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[
}

resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{
Deployment: prompt.DeploymentID,
Prompt: []string{prompt.Prompt},
MaxTokens: &prompt.MaxTokens,
Temperature: &prompt.Temperature,
TopP: &prompt.TopP,
N: &prompt.N,
Stop: prompt.Stop,
DeploymentName: &prompt.DeploymentID,
Prompt: []string{prompt.Prompt},
MaxTokens: &prompt.MaxTokens,
Temperature: &prompt.Temperature,
TopP: &prompt.TopP,
N: &prompt.N,
Stop: prompt.Stop,
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting completion api: %w", err)
Expand Down Expand Up @@ -280,7 +281,7 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
}

if len(messages.Messages) == 0 {
return nil, fmt.Errorf("messages are required for chat-completion operation")
return nil, errors.New("messages are required for chat-completion operation")
}

if messages.DeploymentID == "" {
Expand All @@ -291,11 +292,32 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
messages.Stop = nil
}

messageReq := make([]azopenai.ChatMessage, len(messages.Messages))
messageReq := make([]azopenai.ChatRequestMessageClassification, len(messages.Messages))
for i, m := range messages.Messages {
messageReq[i] = azopenai.ChatMessage{
Role: to.Ptr(azopenai.ChatRole(m.Role)),
Content: to.Ptr(m.Message),
currentMsg := m.Message
switch azopenai.ChatRole(m.Role) {
case azopenai.ChatRoleUser:
messageReq[i] = &azopenai.ChatRequestUserMessage{
Content: azopenai.NewChatRequestUserMessageContent(currentMsg),
}
case azopenai.ChatRoleAssistant:
messageReq[i] = &azopenai.ChatRequestAssistantMessage{
Content: &currentMsg,
}
case azopenai.ChatRoleFunction:
messageReq[i] = &azopenai.ChatRequestFunctionMessage{
Content: &currentMsg,
}
case azopenai.ChatRoleSystem:
messageReq[i] = &azopenai.ChatRequestSystemMessage{
Content: &currentMsg,
}
case azopenai.ChatRoleTool:
messageReq[i] = &azopenai.ChatRequestToolMessage{
Content: &currentMsg,
}
default:
return nil, fmt.Errorf("invalid role: %s", m.Role)
}
}

Expand All @@ -305,13 +327,13 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
}

res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{
Deployment: messages.DeploymentID,
MaxTokens: maxTokens,
Temperature: &messages.Temperature,
TopP: &messages.TopP,
N: &messages.N,
Messages: messageReq,
Stop: messages.Stop,
DeploymentName: &messages.DeploymentID,
MaxTokens: maxTokens,
Temperature: &messages.Temperature,
TopP: &messages.TopP,
N: &messages.N,
Messages: messageReq,
Stop: messages.Stop,
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting chat completion api: %w", err)
Expand Down Expand Up @@ -343,8 +365,8 @@ func (p *AzOpenAI) getEmbedding(ctx context.Context, messageRequest []byte, meta
}

res, err := p.client.GetEmbeddings(ctx, azopenai.EmbeddingsOptions{
Deployment: message.DeploymentID,
Input: []string{message.Message},
DeploymentName: &message.DeploymentID,
Input: []string{message.Message},
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting embedding api: %w", err)
Expand Down
32 changes: 16 additions & 16 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ require (
cloud.google.com/go/secretmanager v1.11.5
cloud.google.com/go/storage v1.36.0
dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.3.0
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v1.1.0
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.0.1
github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.1.0
github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs v1.1.0
github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus v1.7.0
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.0.3
github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.2.0
github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs v1.2.1
github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus v1.7.1
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/eventgrid/armeventgrid/v2 v2.2.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/eventhub/armeventhub v1.2.0
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0
Expand Down Expand Up @@ -51,7 +51,7 @@ require (
github.com/chebyrash/promise v0.0.0-20230709133807-42ec49ba1459
github.com/cinience/go_rocketmq v0.0.2
github.com/cloudevents/sdk-go/binding/format/protobuf/v2 v2.14.0
github.com/cloudevents/sdk-go/v2 v2.14.0
github.com/cloudevents/sdk-go/v2 v2.15.2
github.com/cloudwego/kitex v0.5.0
github.com/cloudwego/kitex-examples v0.1.1
github.com/cyphar/filepath-securejoin v0.2.4
Expand Down Expand Up @@ -117,10 +117,10 @@ require (
go.mongodb.org/mongo-driver v1.12.1
go.uber.org/multierr v1.11.0
go.uber.org/ratelimit v0.3.0
golang.org/x/crypto v0.22.0
golang.org/x/crypto v0.24.0
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a
golang.org/x/mod v0.14.0
golang.org/x/net v0.24.0
golang.org/x/mod v0.17.0
golang.org/x/net v0.26.0
golang.org/x/oauth2 v0.17.0
google.golang.org/api v0.162.0
google.golang.org/grpc v1.63.0
Expand All @@ -147,7 +147,7 @@ require (
github.com/99designs/keyring v1.2.1 // indirect
github.com/AthenZ/athenz v1.10.39 // indirect
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
github.com/DataDog/zstd v1.5.2 // indirect
Expand Down Expand Up @@ -383,12 +383,12 @@ require (
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/zap v1.24.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/term v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/term v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.17.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de // indirect
Expand Down
Loading

0 comments on commit 2e35e1f

Please sign in to comment.