Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for extra_body parameter for embeddings API #906

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
14 changes: 14 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ func withBody(body any) requestOption {
}
}

// withExtraBody returns a requestOption that merges the entries of extraBody
// into the request body previously stored in args.body.
//
// The merge only happens when args.body holds a map[string]any; any other
// body type (for example a request struct) is left untouched. Keys from
// extraBody overwrite keys of the same name that are already in the body map.
// The body map is modified in place.
func withExtraBody(extraBody map[string]any) requestOption {
	return func(args *requestOptions) {
		// Nothing to merge: leave the body exactly as it is.
		if len(extraBody) == 0 {
			return
		}

		// Only a map body can take arbitrary extra fields.
		bodyMap, ok := args.body.(map[string]any)
		if !ok {
			return
		}

		// A typed nil map passes the assertion above, but writing to it
		// would panic, so allocate a map and store it back on args.
		if bodyMap == nil {
			bodyMap = make(map[string]any, len(extraBody))
			args.body = bodyMap
		}

		for key, value := range extraBody {
			bodyMap[key] = value
		}
	}
}

func withContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
Expand Down
32 changes: 31 additions & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"math"
"net/http"
Expand Down Expand Up @@ -160,6 +161,9 @@ type EmbeddingRequest struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
// in the request body that may not be explicitly defined in this struct.
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

func (r EmbeddingRequest) Convert() EmbeddingRequest {
Expand Down Expand Up @@ -187,6 +191,9 @@ type EmbeddingRequestStrings struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
// in the request body that may not be explicitly defined in this struct.
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
Expand All @@ -196,6 +203,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
User: r.User,
EncodingFormat: r.EncodingFormat,
Dimensions: r.Dimensions,
ExtraBody: r.ExtraBody,
}
}

Expand All @@ -219,6 +227,9 @@ type EmbeddingRequestTokens struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
// in the request body that may not be explicitly defined in this struct.
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
Expand All @@ -228,6 +239,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
User: r.User,
EncodingFormat: r.EncodingFormat,
Dimensions: r.Dimensions,
ExtraBody: r.ExtraBody,
}
}

Expand All @@ -241,11 +253,29 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) {
baseReq := conv.Convert()

// The body map is used to dynamically construct the request payload for the embedding API.
// Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields
// based on their presence, avoiding unnecessary or empty fields in the request.
extraBody := baseReq.ExtraBody
baseReq.ExtraBody = nil

// Serialize baseReq to JSON
jsonData, err := json.Marshal(baseReq)
if err != nil {
return
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR! Do I understand correctly that we'll need to extend logic here every time EmbeddingRequest gets updated?

Copy link
Owner

@sashabaranov sashabaranov Nov 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, could we maybe serialize baseReq to json and then back to map[string]any? That would allow us not to keep this logic updated

Copy link
Contributor Author

@AyushSawant18588 AyushSawant18588 Nov 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we will need to extend the logic here whenever EmbeddingRequest gets updated. That said, I think this map approach is the way to support the extra_body param in Go.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, could we maybe serialize baseReq to json and then back to map[string]any? That would allow us not to keep this logic updated

Okay will look into it and update the PR

Copy link
Contributor Author

@AyushSawant18588 AyushSawant18588 Dec 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sashabaranov I have updated the PR and tested the changes, can you please review the changes?


// Deserialize JSON to map[string]any
var body map[string]any
_ = json.Unmarshal(jsonData, &body)

req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
withBody(baseReq),
withBody(body), // Main request body.
withExtraBody(extraBody), // Merge ExtraBody fields.
)
if err != nil {
return
Expand Down
46 changes: 45 additions & 1 deletion embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) {
t.Fatalf("Expected embedding request to contain model field")
}

// test embedding request with strings and extra_body param
embeddingReqWithExtraBody := openai.EmbeddingRequest{
Input: []string{
"The food was delicious and the waiter",
"Other examples of embedding request",
},
Model: model,
ExtraBody: map[string]any{
"input_type": "query",
"truncate": "NONE",
},
}
marshaled, err = json.Marshal(embeddingReqWithExtraBody)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}

// test embedding request with strings
embeddingReqStrings := openai.EmbeddingRequestStrings{
Input: []string{
Expand Down Expand Up @@ -124,7 +142,33 @@ func TestEmbeddingEndpoint(t *testing.T) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}

// test create embeddings with strings (simple embedding request)
// test create embeddings with strings (ExtraBody in request)
res, err = client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
ExtraBody: map[string]any{
"input_type": "query",
"truncate": "NONE",
},
Dimensions: 1,
},
)
checks.NoError(t, err, "CreateEmbeddings error")
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}

// test create embeddings with a non-serializable Input (expects a JSON marshal error)
_, err = client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Input: make(chan int), // Channels are not serializable
Model: "example_model",
},
)
checks.HasError(t, err, "CreateEmbeddings error")

// test failed (Serialize JSON error)
res, err = client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Expand Down
Loading