diff --git a/client.go b/client.go index ed8595e0..4782726c 100644 --- a/client.go +++ b/client.go @@ -84,6 +84,20 @@ func withBody(body any) requestOption { } } +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + // Assert that args.body is a map[string]any. + bodyMap, ok := args.body.(map[string]any) + if ok { + // If it's a map[string]any then only add extraBody + // fields to args.body otherwise keep only fields in request struct. + for key, value := range extraBody { + bodyMap[key] = value + } + } + } +} + func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType) diff --git a/embeddings.go b/embeddings.go index 4a0e682d..8593f8b5 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/binary" + "encoding/json" "errors" "math" "net/http" @@ -160,6 +161,9 @@ type EmbeddingRequest struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -187,6 +191,9 @@ type EmbeddingRequestStrings struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { @@ -196,6 +203,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -219,6 +227,9 @@ type EmbeddingRequestTokens struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { @@ -228,6 +239,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -241,11 +253,29 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() + + // The body map is used to dynamically construct the request payload for the embedding API. + // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields + // based on their presence, avoiding unnecessary or empty fields in the request. + extraBody := baseReq.ExtraBody + baseReq.ExtraBody = nil + + // Serialize baseReq to JSON + jsonData, err := json.Marshal(baseReq) + if err != nil { + return + } + + // Deserialize JSON to map[string]any + var body map[string]any + _ = json.Unmarshal(jsonData, &body) + req, err := c.newRequest( ctx, http.MethodPost, c.fullURL("/embeddings", withModel(string(baseReq.Model))), - withBody(baseReq), + withBody(body), // Main request body. + withExtraBody(extraBody), // Merge ExtraBody fields. ) if err != nil { return diff --git a/embeddings_test.go b/embeddings_test.go index 43897816..07f1262c 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) { t.Fatalf("Expected embedding request to contain model field") } + // test embedding request with strings and extra_body param + embeddingReqWithExtraBody := openai.EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + } + marshaled, err = json.Marshal(embeddingReqWithExtraBody) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + // test embedding request with strings embeddingReqStrings := openai.EmbeddingRequestStrings{ Input: []string{ @@ -124,7 +142,33 @@ func TestEmbeddingEndpoint(t *testing.T) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } - // test create embeddings with strings (simple embedding request) + // test create embeddings with strings (ExtraBody in request) + res, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + Dimensions: 1, + }, + ) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings (ExtraBody in request and ) + _, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Input: make(chan int), // Channels are not serializable + Model: "example_model", + }, + ) + checks.HasError(t, err, "CreateEmbeddings error") + + // test failed (Serialize JSON error) res, err = client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{